From f54b07121f81a56145fb118a2e18841fc135717d Mon Sep 17 00:00:00 2001 From: Ethan Nicholas Date: Thu, 19 Jan 2017 10:44:45 -0500 Subject: Added constant propagation and better variable liveness tracking to skslc. This allows skslc to track the values of variables with constant values across multiple statements and replace variable references with constant values where appropriate. The improved liveness tracking allows skslc to realize that a variable is no longer alive if all references to it have been replaced. It is not yet doing much with this information; better dead code elimination is coming in a followup change. BUG=skia: Change-Id: I6bf267d478b769caf0063ac3597dc16bbe618cb4 Reviewed-on: https://skia-review.googlesource.com/7033 Commit-Queue: Ethan Nicholas Reviewed-by: Greg Daniel --- src/sksl/SkSLCFGGenerator.cpp | 159 ++++++++++++++++------------- src/sksl/SkSLCFGGenerator.h | 19 +++- src/sksl/SkSLCompiler.cpp | 119 +++++++++++++++------ src/sksl/SkSLCompiler.h | 7 +- src/sksl/SkSLIRGenerator.cpp | 78 +++++--------- src/sksl/SkSLIRGenerator.h | 18 ++-- src/sksl/SkSLSPIRVCodeGenerator.cpp | 2 +- src/sksl/SkSLToken.h | 22 ++++ src/sksl/ir/SkSLBinaryExpression.h | 18 +++- src/sksl/ir/SkSLBlock.h | 8 +- src/sksl/ir/SkSLConstructor.h | 19 +++- src/sksl/ir/SkSLDoStatement.h | 2 +- src/sksl/ir/SkSLExpression.h | 25 ++++- src/sksl/ir/SkSLExpressionStatement.h | 2 +- src/sksl/ir/SkSLFieldAccess.h | 2 +- src/sksl/ir/SkSLForStatement.h | 4 +- src/sksl/ir/SkSLFunctionCall.h | 2 +- src/sksl/ir/SkSLIfStatement.h | 2 +- src/sksl/ir/SkSLIndexExpression.h | 4 +- src/sksl/ir/SkSLPostfixExpression.h | 2 +- src/sksl/ir/SkSLPrefixExpression.h | 2 +- src/sksl/ir/SkSLProgram.h | 6 +- src/sksl/ir/SkSLReturnStatement.h | 2 +- src/sksl/ir/SkSLSwizzle.h | 2 +- src/sksl/ir/SkSLTernaryExpression.h | 6 +- src/sksl/ir/SkSLVarDeclarations.h | 2 +- src/sksl/ir/SkSLVarDeclarationsStatement.h | 4 +- src/sksl/ir/SkSLVariable.h | 12 ++- src/sksl/ir/SkSLVariableReference.h | 71 ++++++++++++- src/sksl/ir/SkSLWhileStatement.h | 2 +- 30 files changed, 412 insertions(+), 211 deletions(-) (limited to 'src') diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp index 964a8dc84a..31bace9fb7 100644 --- a/src/sksl/SkSLCFGGenerator.cpp +++ b/src/sksl/SkSLCFGGenerator.cpp @@ -54,8 +54,8 @@ void CFG::dump() { printf("Block %d\n-------\nBefore: ", (int) i); const char* separator = ""; for (auto iter = fBlocks[i].fBefore.begin(); iter != fBlocks[i].fBefore.end(); iter++) { - printf("%s%s = %s", separator, iter->first->description().c_str(), - iter->second ? iter->second->description().c_str() : ""); + printf("%s%s = %s", separator, iter->first->description().c_str(), + *iter->second ? (*iter->second)->description().c_str() : ""); separator = ", "; } printf("\nEntrances: "); @@ -66,7 +66,10 @@ void CFG::dump() { } printf("\n"); for (size_t j = 0; j < fBlocks[i].fNodes.size(); j++) { - printf("Node %d: %s\n", (int) j, fBlocks[i].fNodes[j].fNode->description().c_str()); + BasicBlock::Node& n = fBlocks[i].fNodes[j]; + printf("Node %d: %s\n", (int) j, n.fKind == BasicBlock::Node::kExpression_Kind + ? (*n.fExpression)->description().c_str() + : n.fStatement->description().c_str()); } printf("Exits: "); separator = ""; @@ -78,96 +81,109 @@ void CFG::dump() { } } -void CFGGenerator::addExpression(CFG& cfg, const Expression* e) { - switch (e->fKind) { +void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr* e, bool constantPropagate) { + ASSERT(e); + switch ((*e)->fKind) { case Expression::kBinary_Kind: { - const BinaryExpression* b = (const BinaryExpression*) e; + BinaryExpression* b = (BinaryExpression*) e->get(); switch (b->fOperator) { case Token::LOGICALAND: // fall through case Token::LOGICALOR: { // this isn't as precise as it could be -- we don't bother to track that if we // early exit from a logical and/or, we know which branch of an 'if' we're going // to hit -- but it won't make much difference in practice. - this->addExpression(cfg, b->fLeft.get()); + this->addExpression(cfg, &b->fLeft, constantPropagate); BlockId start = cfg.fCurrent; cfg.newBlock(); - this->addExpression(cfg, b->fRight.get()); + this->addExpression(cfg, &b->fRight, constantPropagate); cfg.newBlock(); cfg.addExit(start, cfg.fCurrent); break; } case Token::EQ: { - this->addExpression(cfg, b->fRight.get()); - this->addLValue(cfg, b->fLeft.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ - BasicBlock::Node::kExpression_Kind, - b + this->addExpression(cfg, &b->fRight, constantPropagate); + this->addLValue(cfg, &b->fLeft); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ + BasicBlock::Node::kExpression_Kind, + constantPropagate, + e, + nullptr }); break; } default: - this->addExpression(cfg, b->fLeft.get()); - this->addExpression(cfg, b->fRight.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ - BasicBlock::Node::kExpression_Kind, - b + this->addExpression(cfg, &b->fLeft, !Token::IsAssignment(b->fOperator)); + this->addExpression(cfg, &b->fRight, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ + BasicBlock::Node::kExpression_Kind, + constantPropagate, + e, + nullptr }); } break; } case Expression::kConstructor_Kind: { - const Constructor* c = (const Constructor*) e; - for (const auto& arg : c->fArguments) { - this->addExpression(cfg, arg.get()); + Constructor* c = (Constructor*) e->get(); + for (auto& arg : c->fArguments) { + this->addExpression(cfg, &arg, constantPropagate); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; } case Expression::kFunctionCall_Kind: { - const FunctionCall* c = (const FunctionCall*) e; - for (const auto& arg : c->fArguments) { - this->addExpression(cfg, arg.get()); + FunctionCall* c = (FunctionCall*) e->get(); + for (auto& arg : c->fArguments) { + this->addExpression(cfg, &arg, constantPropagate); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; } case Expression::kFieldAccess_Kind: - this->addExpression(cfg, ((const FieldAccess*) e)->fBase.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((FieldAccess*) e->get())->fBase, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kIndex_Kind: - this->addExpression(cfg, ((const IndexExpression*) e)->fBase.get()); - this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((IndexExpression*) e->get())->fBase, constantPropagate); + this->addExpression(cfg, &((IndexExpression*) e->get())->fIndex, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kPrefix_Kind: - this->addExpression(cfg, ((const PrefixExpression*) e)->fOperand.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((PrefixExpression*) e->get())->fOperand, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kPostfix_Kind: - this->addExpression(cfg, ((const PostfixExpression*) e)->fOperand.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((PostfixExpression*) e->get())->fOperand, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kSwizzle_Kind: - this->addExpression(cfg, ((const Swizzle*) e)->fBase.get()); - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + this->addExpression(cfg, &((Swizzle*) e->get())->fBase, constantPropagate); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kBoolLiteral_Kind: // fall through case Expression::kFloatLiteral_Kind: // fall through case Expression::kIntLiteral_Kind: // fall through case Expression::kVariableReference_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, + constantPropagate, e, nullptr }); break; case Expression::kTernary_Kind: { - const TernaryExpression* t = (const TernaryExpression*) e; - this->addExpression(cfg, t->fTest.get()); + TernaryExpression* t = (TernaryExpression*) e->get(); + this->addExpression(cfg, &t->fTest, constantPropagate); BlockId start = cfg.fCurrent; cfg.newBlock(); - this->addExpression(cfg, t->fIfTrue.get()); + this->addExpression(cfg, &t->fIfTrue, constantPropagate); BlockId next = cfg.newBlock(); cfg.fCurrent = start; cfg.newBlock(); - this->addExpression(cfg, t->fIfFalse.get()); + this->addExpression(cfg, &t->fIfFalse, constantPropagate); cfg.addExit(cfg.fCurrent, next); cfg.fCurrent = next; break; @@ -181,17 +197,17 @@ void CFGGenerator::addExpression(CFG& cfg, const Expression* e) { } // adds expressions that are evaluated as part of resolving an lvalue -void CFGGenerator::addLValue(CFG& cfg, const Expression* e) { - switch (e->fKind) { +void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr* e) { + switch ((*e)->fKind) { case Expression::kFieldAccess_Kind: - this->addLValue(cfg, ((const FieldAccess*) e)->fBase.get()); + this->addLValue(cfg, &((FieldAccess&) **e).fBase); break; case Expression::kIndex_Kind: - this->addLValue(cfg, ((const IndexExpression*) e)->fBase.get()); - this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get()); + this->addLValue(cfg, &((IndexExpression&) **e).fBase); + this->addExpression(cfg, &((IndexExpression&) **e).fIndex, true); break; case Expression::kSwizzle_Kind: - this->addLValue(cfg, ((const Swizzle*) e)->fBase.get()); + this->addLValue(cfg, &((Swizzle&) **e).fBase); break; case Expression::kVariableReference_Kind: break; @@ -210,8 +226,8 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { } break; case Statement::kIf_Kind: { - const IfStatement* ifs = (const IfStatement*) s; - this->addExpression(cfg, ifs->fTest.get()); + IfStatement* ifs = (IfStatement*) s; + this->addExpression(cfg, &ifs->fTest, true); BlockId start = cfg.fCurrent; cfg.newBlock(); this->addStatement(cfg, ifs->fIfTrue.get()); @@ -228,49 +244,54 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kExpression_Kind: { - this->addExpression(cfg, ((ExpressionStatement&) *s).fExpression.get()); + this->addExpression(cfg, &((ExpressionStatement&) *s).fExpression, true); break; } case Statement::kVarDeclarations_Kind: { - const VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s); - for (const auto& vd : decls.fDeclaration->fVars) { + VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s); + for (auto& vd : decls.fDeclaration->fVars) { if (vd.fValue) { - this->addExpression(cfg, vd.fValue.get()); + this->addExpression(cfg, &vd.fValue, true); } } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); break; } case Statement::kDiscard_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kReturn_Kind: { - const ReturnStatement& r = ((ReturnStatement&) *s); + ReturnStatement& r = ((ReturnStatement&) *s); if (r.fExpression) { - this->addExpression(cfg, r.fExpression.get()); + this->addExpression(cfg, &r.fExpression, true); } - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.fCurrent = cfg.newIsolatedBlock(); break; } case Statement::kBreak_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.addExit(cfg.fCurrent, fLoopExits.top()); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kContinue_Kind: - cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s }); + cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false, + nullptr, s }); cfg.addExit(cfg.fCurrent, fLoopContinues.top()); cfg.fCurrent = cfg.newIsolatedBlock(); break; case Statement::kWhile_Kind: { - const WhileStatement* w = (const WhileStatement*) s; + WhileStatement* w = (WhileStatement*) s; BlockId loopStart = cfg.newBlock(); fLoopContinues.push(loopStart); BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); - this->addExpression(cfg, w->fTest.get()); + this->addExpression(cfg, &w->fTest, true); BlockId test = cfg.fCurrent; cfg.addExit(test, loopExit); cfg.newBlock(); @@ -282,13 +303,13 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kDo_Kind: { - const DoStatement* d = (const DoStatement*) s; + DoStatement* d = (DoStatement*) s; BlockId loopStart = cfg.newBlock(); fLoopContinues.push(loopStart); BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); this->addStatement(cfg, d->fStatement.get()); - this->addExpression(cfg, d->fTest.get()); + this->addExpression(cfg, &d->fTest, true); cfg.addExit(cfg.fCurrent, loopExit); cfg.addExit(cfg.fCurrent, loopStart); fLoopContinues.pop(); @@ -297,7 +318,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { break; } case Statement::kFor_Kind: { - const ForStatement* f = (const ForStatement*) s; + ForStatement* f = (ForStatement*) s; if (f->fInitializer) { this->addStatement(cfg, f->fInitializer.get()); } @@ -307,7 +328,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { BlockId loopExit = cfg.newIsolatedBlock(); fLoopExits.push(loopExit); if (f->fTest) { - this->addExpression(cfg, f->fTest.get()); + this->addExpression(cfg, &f->fTest, true); BlockId test = cfg.fCurrent; cfg.addExit(test, loopExit); } @@ -316,9 +337,9 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) { cfg.addExit(cfg.fCurrent, next); cfg.fCurrent = next; if (f->fNext) { - this->addExpression(cfg, f->fNext.get()); + this->addExpression(cfg, &f->fNext, true); } - cfg.addExit(next, loopStart); + cfg.addExit(cfg.fCurrent, loopStart); fLoopContinues.pop(); fLoopExits.pop(); cfg.fCurrent = loopExit; diff --git a/src/sksl/SkSLCFGGenerator.h b/src/sksl/SkSLCFGGenerator.h index c37850112c..337fdfac35 100644 --- a/src/sksl/SkSLCFGGenerator.h +++ b/src/sksl/SkSLCFGGenerator.h @@ -27,14 +27,23 @@ struct BasicBlock { }; Kind fKind; - const IRNode* fNode; + // if false, this node should not be subject to constant propagation. This happens with + // compound assignment (i.e. x *= 2), in which the value x is used as an rvalue for + // multiplication by 2 and then as an lvalue for assignment purposes. Since there is only + // one "x" node, replacing it with a constant would break the assignment and we suppress + // it. Down the road, we should handle this more elegantly by substituting a regular + // assignment if the target is constant (i.e. x = 1; x *= 2; should become x = 1; x = 1 * 2; + // and then collapse down to a simple x = 2;). + bool fConstantPropagation; + std::unique_ptr* fExpression; + const Statement* fStatement; }; - + std::vector fNodes; std::set fEntrances; std::set fExits; // variable definitions upon entering this basic block (null expression = undefined) - std::unordered_map fBefore; + DefinitionMap fBefore; }; struct CFG { @@ -77,9 +86,9 @@ public: private: void addStatement(CFG& cfg, const Statement* s); - void addExpression(CFG& cfg, const Expression* e); + void addExpression(CFG& cfg, std::unique_ptr* e, bool constantPropagate); - void addLValue(CFG& cfg, const Expression* e); + void addLValue(CFG& cfg, std::unique_ptr* e); std::stack fLoopContinues; std::stack fLoopExits; diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp index 9faf836156..743745ad14 100644 --- a/src/sksl/SkSLCompiler.cpp +++ b/src/sksl/SkSLCompiler.cpp @@ -156,8 +156,8 @@ Compiler::~Compiler() { } // add the definition created by assigning to the lvalue to the definition set -void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, - std::unordered_map* definitions) { +void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr* expr, + DefinitionMap* definitions) { switch (lvalue->fKind) { case Expression::kVariableReference_Kind: { const Variable& var = ((VariableReference*) lvalue)->fVariable; @@ -174,19 +174,19 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, // but since we pass foo as a whole it is flagged as an error) unless we perform a much // more complicated whole-program analysis. This is probably good enough. this->addDefinition(((Swizzle*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, definitions); break; case Expression::kIndex_Kind: // see comments in Swizzle this->addDefinition(((IndexExpression*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, definitions); break; case Expression::kFieldAccess_Kind: // see comments in Swizzle this->addDefinition(((FieldAccess*) lvalue)->fBase.get(), - fContext.fDefined_Expression.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, definitions); break; default: @@ -197,25 +197,58 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr, // add local variables defined by this node to the set void Compiler::addDefinitions(const BasicBlock::Node& node, - std::unordered_map* definitions) { + DefinitionMap* definitions) { switch (node.fKind) { case BasicBlock::Node::kExpression_Kind: { - const Expression* expr = (Expression*) node.fNode; - if (expr->fKind == Expression::kBinary_Kind) { - const BinaryExpression* b = (BinaryExpression*) expr; - if (b->fOperator == Token::EQ) { - this->addDefinition(b->fLeft.get(), b->fRight.get(), definitions); + ASSERT(node.fExpression); + const Expression* expr = (Expression*) node.fExpression->get(); + switch (expr->fKind) { + case Expression::kBinary_Kind: { + BinaryExpression* b = (BinaryExpression*) expr; + if (b->fOperator == Token::EQ) { + this->addDefinition(b->fLeft.get(), &b->fRight, definitions); + } else if (Token::IsAssignment(b->fOperator)) { + this->addDefinition( + b->fLeft.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, + definitions); + + } + break; + } + case Expression::kPrefix_Kind: { + const PrefixExpression* p = (PrefixExpression*) expr; + if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { + this->addDefinition( + p->fOperand.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, + definitions); + } + break; } + case Expression::kPostfix_Kind: { + const PostfixExpression* p = (PostfixExpression*) expr; + if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { + this->addDefinition( + p->fOperand.get(), + (std::unique_ptr*) &fContext.fDefined_Expression, + definitions); + + } + break; + } + default: + break; } break; } case BasicBlock::Node::kStatement_Kind: { - const Statement* stmt = (Statement*) node.fNode; + const Statement* stmt = (Statement*) node.fStatement; if (stmt->fKind == Statement::kVarDeclarations_Kind) { - const VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt; - for (const VarDeclaration& decl : vd->fDeclaration->fVars) { + VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt; + for (VarDeclaration& decl : vd->fDeclaration->fVars) { if (decl.fValue) { - (*definitions)[decl.fVar] = decl.fValue.get(); + (*definitions)[decl.fVar] = &decl.fValue; } } } @@ -228,7 +261,7 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set* workList) { BasicBlock& block = cfg->fBlocks[blockId]; // compute definitions after this block - std::unordered_map after = block.fBefore; + DefinitionMap after = block.fBefore; for (const BasicBlock::Node& n : block.fNodes) { this->addDefinitions(n, &after); } @@ -237,19 +270,20 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set* workList) { for (BlockId exitId : block.fExits) { BasicBlock& exit = cfg->fBlocks[exitId]; for (const auto& pair : after) { - const Expression* e1 = pair.second; - if (exit.fBefore.find(pair.first) == exit.fBefore.end()) { + std::unique_ptr* e1 = pair.second; + auto found = exit.fBefore.find(pair.first); + if (found == exit.fBefore.end()) { + // exit has no definition for it, just copy it + workList->insert(exitId); exit.fBefore[pair.first] = e1; } else { - const Expression* e2 = exit.fBefore[pair.first]; + // exit has a (possibly different) value already defined + std::unique_ptr* e2 = exit.fBefore[pair.first]; if (e1 != e2) { // definition has changed, merge and add exit block to worklist workList->insert(exitId); - if (!e1 || !e2) { - exit.fBefore[pair.first] = nullptr; - } else { - exit.fBefore[pair.first] = fContext.fDefined_Expression.get(); - } + exit.fBefore[pair.first] = + (std::unique_ptr*) &fContext.fDefined_Expression; } } } @@ -258,12 +292,13 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set* workList) { // returns a map which maps all local variables in the function to null, indicating that their value // is initially unknown -static std::unordered_map compute_start_state(const CFG& cfg) { - std::unordered_map result; +static DefinitionMap compute_start_state(const CFG& cfg) { + DefinitionMap result; for (const auto& block : cfg.fBlocks) { for (const auto& node : block.fNodes) { if (node.fKind == BasicBlock::Node::kStatement_Kind) { - const Statement* s = (Statement*) node.fNode; + ASSERT(node.fStatement); + const Statement* s = node.fStatement; if (s->fKind == Statement::kVarDeclarations_Kind) { const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s; for (const VarDeclaration& decl : vd->fDeclaration->fVars) { @@ -295,19 +330,37 @@ void Compiler::scanCFG(const FunctionDefinition& f) { for (size_t i = 0; i < cfg.fBlocks.size(); i++) { if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() && cfg.fBlocks[i].fNodes.size()) { - this->error(cfg.fBlocks[i].fNodes[0].fNode->fPosition, SkString("unreachable")); + Position p; + switch (cfg.fBlocks[i].fNodes[0].fKind) { + case BasicBlock::Node::kStatement_Kind: + p = cfg.fBlocks[i].fNodes[0].fStatement->fPosition; + break; + case BasicBlock::Node::kExpression_Kind: + p = (*cfg.fBlocks[i].fNodes[0].fExpression)->fPosition; + break; + } + this->error(p, SkString("unreachable")); } } if (fErrorCount) { return; } - // check for undefined variables - for (const BasicBlock& b : cfg.fBlocks) { - std::unordered_map definitions = b.fBefore; - for (const BasicBlock::Node& n : b.fNodes) { + // check for undefined variables, perform constant propagation + for (BasicBlock& b : cfg.fBlocks) { + DefinitionMap definitions = b.fBefore; + for (BasicBlock::Node& n : b.fNodes) { if (n.fKind == BasicBlock::Node::kExpression_Kind) { - const Expression* expr = (const Expression*) n.fNode; + ASSERT(n.fExpression); + Expression* expr = n.fExpression->get(); + if (n.fConstantPropagation) { + std::unique_ptr optimized = expr->constantPropagate(*fIRGenerator, + definitions); + if (optimized) { + n.fExpression->reset(optimized.release()); + expr = n.fExpression->get(); + } + } if (expr->fKind == Expression::kVariableReference_Kind) { const Variable& var = ((VariableReference*) expr)->fVariable; if (var.fStorage == Variable::kLocal_Storage && diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h index 0f893f7e64..fdca12d2cf 100644 --- a/src/sksl/SkSLCompiler.h +++ b/src/sksl/SkSLCompiler.h @@ -60,11 +60,10 @@ public: } private: - void addDefinition(const Expression* lvalue, const Expression* expr, - std::unordered_map* definitions); + void addDefinition(const Expression* lvalue, std::unique_ptr* expr, + DefinitionMap* definitions); - void addDefinitions(const BasicBlock::Node& node, - std::unordered_map* definitions); + void addDefinitions(const BasicBlock::Node& node, DefinitionMap* definitions); void scanCFG(CFG* cfg, BlockId block, std::set* workList); diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 9f06c97d11..55d9d2c8d6 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -551,11 +551,11 @@ std::unique_ptr IRGenerator::convertInterfaceBlock(const ASTInte } } Type* type = new Type(intf.fPosition, intf.fInterfaceName, fields); - fSymbolTable->takeOwnership(type); + old->takeOwnership(type); SkString name = intf.fValueName.size() > 0 ? intf.fValueName : intf.fInterfaceName; Variable* var = new Variable(intf.fPosition, intf.fModifiers, name, *type, Variable::kGlobal_Storage); - fSymbolTable->takeOwnership(var); + old->takeOwnership(var); if (intf.fValueName.size()) { old->addWithoutOwnership(intf.fValueName, var); } else { @@ -624,19 +624,22 @@ std::unique_ptr IRGenerator::convertIdentifier(const ASTIdentifier& f->fFunctions)); } case Symbol::kVariable_Kind: { - const Variable* var = (const Variable*) result; - this->markReadFrom(*var); + Variable* var = (Variable*) result; if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN && fSettings->fFlipY && (!fSettings->fCaps || !fSettings->fCaps->fragCoordConventionsExtensionString())) { fInputs.fRTHeight = true; } - return std::unique_ptr(new VariableReference(identifier.fPosition, - *var)); + // default to kRead_RefKind; this will be corrected later if the variable is written to + return std::unique_ptr(new VariableReference( + identifier.fPosition, + *var, + VariableReference::kRead_RefKind)); } case Symbol::kField_Kind: { const Field* field = (const Field*) result; - VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner); + VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner, + VariableReference::kRead_RefKind); return std::unique_ptr(new FieldAccess( std::unique_ptr(base), field->fFieldIndex, @@ -690,28 +693,6 @@ static bool is_matrix_multiply(const Type& left, const Type& right) { return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind; } -static bool is_assignment(Token::Kind op) { - switch (op) { - case Token::EQ: // fall through - case Token::PLUSEQ: // fall through - case Token::MINUSEQ: // fall through - case Token::STAREQ: // fall through - case Token::SLASHEQ: // fall through - case Token::PERCENTEQ: // fall through - case Token::SHLEQ: // fall through - case Token::SHREQ: // fall through - case Token::BITWISEOREQ: // fall through - case Token::BITWISEXOREQ: // fall through - case Token::BITWISEANDEQ: // fall through - case Token::LOGICALOREQ: // fall through - case Token::LOGICALXOREQ: // fall through - case Token::LOGICALANDEQ: - return true; - default: - return false; - } -} - /** * Determines the operand and result types of a binary expression. Returns true if the expression is * legal, false otherwise. If false, the values of the out parameters are undefined. @@ -842,14 +823,9 @@ static bool determine_binary_type(const Context& context, return false; } -/** - * If both operands are compile-time constants and can be folded, returns an expression representing - * the folded value. Otherwise, returns null. Note that unlike most other functions here, null does - * not represent a compilation error. - */ std::unique_ptr IRGenerator::constantFold(const Expression& left, Token::Kind op, - const Expression& right) { + const Expression& right) const { // Note that we expressly do not worry about precision and overflow here -- we use the maximum // precision to calculate the results and hope the result makes sense. The plan is to move the // Skia caps into SkSL, so we have access to all of them including the precisions of the various @@ -943,15 +919,16 @@ std::unique_ptr IRGenerator::convertBinaryExpression( const Type* rightType; const Type* resultType; if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType, - &rightType, &resultType, !is_assignment(expression.fOperator))) { + &rightType, &resultType, + !Token::IsAssignment(expression.fOperator))) { fErrors.error(expression.fPosition, "type mismatch: '" + Token::OperatorName(expression.fOperator) + "' cannot operate on '" + left->fType.fName + "', '" + right->fType.fName + "'"); return nullptr; } - if (is_assignment(expression.fOperator)) { - this->markWrittenTo(*left); + if (Token::IsAssignment(expression.fOperator)) { + this->markWrittenTo(*left, expression.fOperator != Token::EQ); } left = this->coerce(std::move(left), *leftType); right = this->coerce(std::move(right), *rightType); @@ -1051,7 +1028,7 @@ std::unique_ptr IRGenerator::call(Position position, return nullptr; } if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) { - this->markWrittenTo(*arguments[i]); + this->markWrittenTo(*arguments[i], true); } } return std::unique_ptr(new FunctionCall(position, *returnType, function, @@ -1261,7 +1238,7 @@ std::unique_ptr IRGenerator::convertPrefixExpression( "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); break; case Token::MINUSMINUS: if (!base->fType.isNumber()) { @@ -1270,7 +1247,7 @@ std::unique_ptr IRGenerator::convertPrefixExpression( "' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); break; case Token::LOGICALNOT: if (base->fType != *fContext.fBool_Type) { @@ -1464,7 +1441,7 @@ std::unique_ptr IRGenerator::convertSuffixExpression( "'++' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); return std::unique_ptr(new PostfixExpression(std::move(base), Token::PLUSPLUS)); case ASTSuffix::kPostDecrement_Kind: @@ -1473,7 +1450,7 @@ std::unique_ptr IRGenerator::convertSuffixExpression( "'--' cannot operate on '" + base->fType.description() + "'"); return nullptr; } - this->markWrittenTo(*base); + this->markWrittenTo(*base, true); return std::unique_ptr(new PostfixExpression(std::move(base), Token::MINUSMINUS)); default: @@ -1496,10 +1473,6 @@ void IRGenerator::checkValid(const Expression& expr) { } } -void IRGenerator::markReadFrom(const Variable& var) { - var.fIsReadFrom = true; -} - static bool has_duplicates(const Swizzle& swizzle) { int bits = 0; for (int idx : swizzle.fComponents) { @@ -1513,7 +1486,7 @@ static bool has_duplicates(const Swizzle& swizzle) { return false; } -void IRGenerator::markWrittenTo(const Expression& expr) { +void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) { switch (expr.fKind) { case Expression::kVariableReference_Kind: { const Variable& var = ((VariableReference&) expr).fVariable; @@ -1521,21 +1494,22 @@ void IRGenerator::markWrittenTo(const Expression& expr) { fErrors.error(expr.fPosition, "cannot modify immutable variable '" + var.fName + "'"); } - var.fIsWrittenTo = true; + ((VariableReference&) expr).setRefKind(readWrite ? VariableReference::kReadWrite_RefKind + : VariableReference::kWrite_RefKind); break; } case Expression::kFieldAccess_Kind: - this->markWrittenTo(*((FieldAccess&) expr).fBase); + this->markWrittenTo(*((FieldAccess&) expr).fBase, readWrite); break; case Expression::kSwizzle_Kind: if (has_duplicates((Swizzle&) expr)) { fErrors.error(expr.fPosition, "cannot write to the same swizzle field more than once"); } - this->markWrittenTo(*((Swizzle&) expr).fBase); + this->markWrittenTo(*((Swizzle&) expr).fBase, readWrite); break; case Expression::kIndex_Kind: - this->markWrittenTo(*((IndexExpression&) expr).fBase); + this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite); break; default: fErrors.error(expr.fPosition, "cannot assign to '" + expr.description() + "'"); diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h index 13b20fbbcc..2ffcb0df26 100644 --- a/src/sksl/SkSLIRGenerator.h +++ b/src/sksl/SkSLIRGenerator.h @@ -88,7 +88,16 @@ public: std::unique_ptr convertModifiersDeclaration( const ASTModifiersDeclaration& m); + /** + * If both operands are compile-time constants and can be folded, returns an expression + * representing the folded value. Otherwise, returns null. Note that unlike most other functions + * here, null does not represent a compilation error. + */ + std::unique_ptr constantFold(const Expression& left, + Token::Kind op, + const Expression& right) const; Program::Inputs fInputs; + const Context& fContext; private: /** @@ -124,11 +133,6 @@ private: std::unique_ptr convertDiscard(const ASTDiscardStatement& d); std::unique_ptr convertDo(const ASTDoStatement& d); std::unique_ptr convertBinaryExpression(const ASTBinaryExpression& expression); - // Returns null if it cannot fold the expression. Note that unlike most other functions here, a - // null return does not represent a compilation error. - std::unique_ptr constantFold(const Expression& left, - Token::Kind op, - const Expression& right); std::unique_ptr convertExtension(const ASTExtension& e); std::unique_ptr convertExpressionStatement(const ASTExpressionStatement& s); std::unique_ptr convertFor(const ASTForStatement& f); @@ -151,10 +155,8 @@ private: std::unique_ptr convertWhile(const ASTWhileStatement& w); void checkValid(const Expression& expr); - void markReadFrom(const Variable& var); - void markWrittenTo(const Expression& expr); + void markWrittenTo(const Expression& expr, bool readWrite); - const Context& fContext; const FunctionDeclaration* fCurrentFunction; const Program::Settings* fSettings; std::unordered_map fCapsMap; diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index 8afd13688c..d43e4c4035 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -2540,7 +2540,7 @@ void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclaratio kind != Program::kFragment_Kind) { continue; } - if (!var->fIsReadFrom && !var->fIsWrittenTo && + if (!var->fReadCount && !var->fWriteCount && !(var->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag))) { diff --git a/src/sksl/SkSLToken.h b/src/sksl/SkSLToken.h index 5c8c2bd215..197781f2a0 100644 --- a/src/sksl/SkSLToken.h +++ b/src/sksl/SkSLToken.h @@ -160,6 +160,28 @@ struct Token { , fKind(kind) , fText(std::move(text)) {} + static bool IsAssignment(Token::Kind op) { + switch (op) { + case Token::EQ: // fall through + case Token::PLUSEQ: // fall through + case Token::MINUSEQ: // fall through + case Token::STAREQ: // fall through + case Token::SLASHEQ: // fall through + case Token::PERCENTEQ: // fall through + case Token::SHLEQ: // fall through + case Token::SHREQ: // fall through + case Token::BITWISEOREQ: // fall through + case Token::BITWISEXOREQ: // fall through + case Token::BITWISEANDEQ: // fall through + case Token::LOGICALOREQ: // fall through + case Token::LOGICALXOREQ: // fall through + case Token::LOGICALANDEQ: + return true; + default: + return false; + } + } + Position fPosition; Kind fKind; // will be the empty string unless the token has variable text content (identifiers, numeric diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h index 132513e7f7..de85e4812b 100644 --- a/src/sksl/ir/SkSLBinaryExpression.h +++ b/src/sksl/ir/SkSLBinaryExpression.h @@ -4,17 +4,19 @@ * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ - + #ifndef SKSL_BINARYEXPRESSION #define SKSL_BINARYEXPRESSION #include "SkSLExpression.h" +#include "SkSLExpression.h" +#include "../SkSLIRGenerator.h" #include "../SkSLToken.h" namespace SkSL { /** - * A binary operation. + * A binary operation. */ struct BinaryExpression : public Expression { BinaryExpression(Position position, std::unique_ptr left, Token::Kind op, @@ -24,14 +26,22 @@ struct BinaryExpression : public Expression { , fOperator(op) , fRight(std::move(right)) {} + virtual std::unique_ptr constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + return irGenerator.constantFold(*fLeft, + fOperator, + *fRight); + } + virtual SkString description() const override { return "(" + fLeft->description() + " " + Token::OperatorName(fOperator) + " " + fRight->description() + ")"; } - const std::unique_ptr fLeft; + std::unique_ptr fLeft; const Token::Kind fOperator; - const std::unique_ptr fRight; + std::unique_ptr fRight; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h index f975d160a0..17970fd561 100644 --- a/src/sksl/ir/SkSLBlock.h +++ b/src/sksl/ir/SkSLBlock.h @@ -20,8 +20,8 @@ struct Block : public Statement { Block(Position position, std::vector> statements, const std::shared_ptr symbols) : INHERITED(position, kBlock_Kind) - , fStatements(std::move(statements)) - , fSymbols(std::move(symbols)) {} + , fSymbols(std::move(symbols)) + , fStatements(std::move(statements)) {} SkString description() const override { SkString result("{"); @@ -33,8 +33,10 @@ struct Block : public Statement { return result; } - const std::vector> fStatements; + // it's important to keep fStatements defined after (and thus destroyed before) fSymbols, + // because destroying statements can modify reference counts in symbols const std::shared_ptr fSymbols; + const std::vector> fStatements; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index 63c692b88e..691bea123a 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -9,6 +9,9 @@ #define SKSL_CONSTRUCTOR #include "SkSLExpression.h" +#include "SkSLFloatLiteral.h" +#include "SkSLIntLiteral.h" +#include "SkSLIRGenerator.h" namespace SkSL { @@ -21,6 +24,20 @@ struct Constructor : public Expression { : INHERITED(position, kConstructor_Kind, type) , fArguments(std::move(arguments)) {} + virtual std::unique_ptr constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + if (fArguments.size() == 1 && fArguments[0]->fKind == Expression::kIntLiteral_Kind && + // promote float(1) to 1.0 + fType == *irGenerator.fContext.fFloat_Type) { + int64_t intValue = ((IntLiteral&) *fArguments[0]).fValue; + return std::unique_ptr(new FloatLiteral(irGenerator.fContext, + fPosition, + intValue)); + } + return nullptr; + } + SkString description() const override { SkString result = fType.description() + "("; SkString separator; @@ -42,7 +59,7 @@ struct Constructor : public Expression { return true; } - const std::vector> fArguments; + std::vector> fArguments; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLDoStatement.h b/src/sksl/ir/SkSLDoStatement.h index 78c0a1b768..e26d3dc974 100644 --- a/src/sksl/ir/SkSLDoStatement.h +++ b/src/sksl/ir/SkSLDoStatement.h @@ -28,7 +28,7 @@ struct DoStatement : public Statement { } const std::unique_ptr fStatement; - const std::unique_ptr fTest; + std::unique_ptr fTest; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index b4ed37c09a..f87d810fc0 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -4,17 +4,24 @@ * Use of this source code is governed by a BSD-style license that can be * found in the LICENSE file. */ - + #ifndef SKSL_EXPRESSION #define SKSL_EXPRESSION -#include "SkSLIRNode.h" #include "SkSLType.h" +#include "SkSLVariable.h" + +#include namespace SkSL { +struct Expression; +class IRGenerator; + +typedef std::unordered_map*> DefinitionMap; + /** - * Abstract supertype of all expressions. + * Abstract supertype of all expressions. */ struct Expression : public IRNode { enum Kind { @@ -45,6 +52,18 @@ struct Expression : public IRNode { return false; } + /** + * Given a map of known constant variable values, substitute them in for references to those + * variables occurring in this expression and its subexpressions. Similar simplifications, such + * as folding a constant binary expression down to a single value, may also be performed. + * Returns a new expression which replaces this expression, or null if no replacements were + * made. If a new expression is returned, this expression is no longer valid. + */ + virtual std::unique_ptr constantPropagate(const IRGenerator& irGenerator, + const DefinitionMap& definitions) { + return nullptr; + } + const Kind fKind; const Type& fType; diff --git a/src/sksl/ir/SkSLExpressionStatement.h b/src/sksl/ir/SkSLExpressionStatement.h index 677c647587..088b1c9ad1 100644 --- a/src/sksl/ir/SkSLExpressionStatement.h +++ b/src/sksl/ir/SkSLExpressionStatement.h @@ -25,7 +25,7 @@ struct ExpressionStatement : public Statement { return fExpression->description() + ";"; } - const std::unique_ptr fExpression; + std::unique_ptr fExpression; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h index fb727e017e..de26a3f626 100644 --- a/src/sksl/ir/SkSLFieldAccess.h +++ b/src/sksl/ir/SkSLFieldAccess.h @@ -35,7 +35,7 @@ struct FieldAccess : public Expression { return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName; } - const std::unique_ptr fBase; + std::unique_ptr fBase; const int fFieldIndex; const OwnerKind fOwnerKind; diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h index ff03d0d7f9..6f03e2bb36 100644 --- a/src/sksl/ir/SkSLForStatement.h +++ b/src/sksl/ir/SkSLForStatement.h @@ -46,8 +46,8 @@ struct ForStatement : public Statement { } const std::unique_ptr fInitializer; - const std::unique_ptr fTest; - const std::unique_ptr fNext; + std::unique_ptr fTest; + std::unique_ptr fNext; const std::unique_ptr fStatement; const std::shared_ptr fSymbols; diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h index 971af366b9..1838076796 100644 --- a/src/sksl/ir/SkSLFunctionCall.h +++ b/src/sksl/ir/SkSLFunctionCall.h @@ -36,7 +36,7 @@ struct FunctionCall : public Expression { } const FunctionDeclaration& fFunction; - const std::vector> fArguments; + std::vector> fArguments; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLIfStatement.h b/src/sksl/ir/SkSLIfStatement.h index f8beded9e8..8667e932ec 100644 --- a/src/sksl/ir/SkSLIfStatement.h +++ b/src/sksl/ir/SkSLIfStatement.h @@ -32,7 +32,7 @@ struct IfStatement : public Statement { return result; } - const std::unique_ptr fTest; + std::unique_ptr fTest; const std::unique_ptr fIfTrue; const std::unique_ptr fIfFalse; diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h index 079dde5e53..d255c7daf6 100644 --- a/src/sksl/ir/SkSLIndexExpression.h +++ b/src/sksl/ir/SkSLIndexExpression.h @@ -55,8 +55,8 @@ struct IndexExpression : public Expression { return fBase->description() + "[" + fIndex->description() + "]"; } - const std::unique_ptr fBase; - const std::unique_ptr fIndex; + std::unique_ptr fBase; + std::unique_ptr fIndex; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLPostfixExpression.h b/src/sksl/ir/SkSLPostfixExpression.h index 01671b5b88..6c9fafe5a0 100644 --- a/src/sksl/ir/SkSLPostfixExpression.h +++ b/src/sksl/ir/SkSLPostfixExpression.h @@ -25,7 +25,7 @@ struct PostfixExpression : public Expression { return fOperand->description() + Token::OperatorName(fOperator); } - const std::unique_ptr fOperand; + std::unique_ptr fOperand; const Token::Kind fOperator; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLPrefixExpression.h b/src/sksl/ir/SkSLPrefixExpression.h index 790c5ab47a..b7db99a0a4 100644 --- a/src/sksl/ir/SkSLPrefixExpression.h +++ b/src/sksl/ir/SkSLPrefixExpression.h @@ -25,7 +25,7 @@ struct PrefixExpression : public Expression { return Token::OperatorName(fOperator) + fOperand->description(); } - const std::unique_ptr fOperand; + std::unique_ptr fOperand; const Token::Kind fOperator; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h index ac49d6dcc7..6a73be6983 100644 --- a/src/sksl/ir/SkSLProgram.h +++ b/src/sksl/ir/SkSLProgram.h @@ -59,8 +59,8 @@ struct Program { , fSettings(settings) , fDefaultPrecision(defaultPrecision) , fContext(context) - , fElements(std::move(elements)) , fSymbols(symbols) + , fElements(std::move(elements)) , fInputs(inputs) {} Kind fKind; @@ -68,8 +68,10 @@ struct Program { // FIXME handle different types; currently it assumes this is for floats Modifiers::Flag fDefaultPrecision; Context* fContext; - std::vector> fElements; + // it's important to keep fElements defined after (and thus destroyed before) fSymbols, + // because destroying elements can modify reference counts in symbols std::shared_ptr fSymbols; + std::vector> fElements; Inputs fInputs; }; diff --git a/src/sksl/ir/SkSLReturnStatement.h b/src/sksl/ir/SkSLReturnStatement.h index c83b45066e..dc5ec9aa9c 100644 --- a/src/sksl/ir/SkSLReturnStatement.h +++ b/src/sksl/ir/SkSLReturnStatement.h @@ -32,7 +32,7 @@ struct ReturnStatement : public Statement { } } - const std::unique_ptr fExpression; + std::unique_ptr fExpression; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h index c9397aec7f..8ad9001ada 100644 --- a/src/sksl/ir/SkSLSwizzle.h +++ b/src/sksl/ir/SkSLSwizzle.h @@ -76,7 +76,7 @@ struct Swizzle : public Expression { return result; } - const std::unique_ptr fBase; + std::unique_ptr fBase; const std::vector fComponents; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLTernaryExpression.h b/src/sksl/ir/SkSLTernaryExpression.h index 4a352536e3..02750049d4 100644 --- a/src/sksl/ir/SkSLTernaryExpression.h +++ b/src/sksl/ir/SkSLTernaryExpression.h @@ -31,9 +31,9 @@ struct TernaryExpression : public Expression { fIfFalse->description() + ")"; } - const std::unique_ptr fTest; - const std::unique_ptr fIfTrue; - const std::unique_ptr fIfFalse; + std::unique_ptr fTest; + std::unique_ptr fIfTrue; + std::unique_ptr fIfFalse; typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclarations.h b/src/sksl/ir/SkSLVarDeclarations.h index 295c0b6997..490259a081 100644 --- a/src/sksl/ir/SkSLVarDeclarations.h +++ b/src/sksl/ir/SkSLVarDeclarations.h @@ -72,7 +72,7 @@ struct VarDeclarations : public ProgramElement { } const Type& fBaseType; - const std::vector fVars; + std::vector fVars; typedef ProgramElement INHERITED; }; diff --git a/src/sksl/ir/SkSLVarDeclarationsStatement.h b/src/sksl/ir/SkSLVarDeclarationsStatement.h index 7a29656593..66b570f853 100644 --- a/src/sksl/ir/SkSLVarDeclarationsStatement.h +++ b/src/sksl/ir/SkSLVarDeclarationsStatement.h @@ -18,14 +18,14 @@ namespace SkSL { */ struct VarDeclarationsStatement : public Statement { VarDeclarationsStatement(std::unique_ptr decl) - : INHERITED(decl->fPosition, kVarDeclarations_Kind) + : INHERITED(decl->fPosition, kVarDeclarations_Kind) , fDeclaration(std::move(decl)) {} SkString description() const override { return fDeclaration->description(); } - const std::shared_ptr fDeclaration; + std::shared_ptr fDeclaration; typedef Statement INHERITED; }; diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h index 39b8482a7b..2c3391dfa2 100644 --- a/src/sksl/ir/SkSLVariable.h +++ b/src/sksl/ir/SkSLVariable.h @@ -33,8 +33,8 @@ struct Variable : public Symbol { , fModifiers(modifiers) , fType(type) , fStorage(storage) - , fIsReadFrom(false) - , fIsWrittenTo(false) {} + , fReadCount(0) + , fWriteCount(0) {} virtual SkString description() const override { return fModifiers.description() + fType.fName + " " + fName; @@ -44,8 +44,12 @@ struct Variable : public Symbol { const Type& fType; const Storage fStorage; - mutable bool fIsReadFrom; - mutable bool fIsWrittenTo; + // Tracks how many sites read from the variable. If this is zero for a non-out variable (or + // becomes zero during optimization), the variable is dead and may be eliminated. + mutable int fReadCount; + // Tracks how many sites write to the variable. If this is zero, the variable is dead and may be + // eliminated. + mutable int fWriteCount; typedef Symbol INHERITED; }; diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h index c6a2ea0511..fecb04e2e5 100644 --- a/src/sksl/ir/SkSLVariableReference.h +++ b/src/sksl/ir/SkSLVariableReference.h @@ -20,16 +20,83 @@ namespace SkSL { * there is only one Variable 'x', but two VariableReferences to it. */ struct VariableReference : public Expression { - VariableReference(Position position, const Variable& variable) + enum RefKind { + kRead_RefKind, + kWrite_RefKind, + kReadWrite_RefKind + }; + + VariableReference(Position position, const Variable& variable, RefKind refKind = kRead_RefKind) : INHERITED(position, kVariableReference_Kind, variable.fType) - , fVariable(variable) {} + , fVariable(variable) + , fRefKind(refKind) { + if (refKind != kRead_RefKind) { + fVariable.fWriteCount++; + } + if (refKind != kWrite_RefKind) { + fVariable.fReadCount++; + } + } + + virtual ~VariableReference() override { + if (fRefKind != kWrite_RefKind) { + fVariable.fReadCount--; + } + } + + RefKind refKind() { + return fRefKind; + } + + void setRefKind(RefKind refKind) { + if (fRefKind != kRead_RefKind) { + fVariable.fWriteCount--; + } + if (fRefKind != kWrite_RefKind) { + fVariable.fReadCount--; + } + if (refKind != kRead_RefKind) { + fVariable.fWriteCount++; + } + if (refKind != kWrite_RefKind) { + fVariable.fReadCount++; + } + fRefKind = refKind; + } SkString description() const override { return fVariable.fName; } + virtual std::unique_ptr constantPropagate( + const IRGenerator& irGenerator, + const DefinitionMap& definitions) override { + auto exprIter = definitions.find(&fVariable); + if (exprIter != definitions.end() && exprIter->second) { + const Expression* expr = exprIter->second->get(); + switch (expr->fKind) { + case Expression::kIntLiteral_Kind: + return std::unique_ptr(new IntLiteral( + irGenerator.fContext, + Position(), + ((IntLiteral*) expr)->fValue)); + case Expression::kFloatLiteral_Kind: + return std::unique_ptr(new FloatLiteral( + irGenerator.fContext, + Position(), + ((FloatLiteral*) expr)->fValue)); + default: + break; + } + } + return nullptr; + } + const Variable& fVariable; +private: + RefKind fRefKind; + typedef Expression INHERITED; }; diff --git a/src/sksl/ir/SkSLWhileStatement.h b/src/sksl/ir/SkSLWhileStatement.h index 7c6a2907c4..a741a0441d 100644 --- a/src/sksl/ir/SkSLWhileStatement.h +++ b/src/sksl/ir/SkSLWhileStatement.h @@ -27,7 +27,7 @@ struct WhileStatement : public Statement { return "while (" + fTest->description() + ") " + fStatement->description(); } - const std::unique_ptr fTest; + std::unique_ptr fTest; const std::unique_ptr fStatement; typedef Statement INHERITED; -- cgit v1.2.3