aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/sksl/SkSLCFGGenerator.cpp159
-rw-r--r--src/sksl/SkSLCFGGenerator.h19
-rw-r--r--src/sksl/SkSLCompiler.cpp119
-rw-r--r--src/sksl/SkSLCompiler.h7
-rw-r--r--src/sksl/SkSLIRGenerator.cpp78
-rw-r--r--src/sksl/SkSLIRGenerator.h18
-rw-r--r--src/sksl/SkSLSPIRVCodeGenerator.cpp2
-rw-r--r--src/sksl/SkSLToken.h22
-rw-r--r--src/sksl/ir/SkSLBinaryExpression.h18
-rw-r--r--src/sksl/ir/SkSLBlock.h8
-rw-r--r--src/sksl/ir/SkSLConstructor.h19
-rw-r--r--src/sksl/ir/SkSLDoStatement.h2
-rw-r--r--src/sksl/ir/SkSLExpression.h25
-rw-r--r--src/sksl/ir/SkSLExpressionStatement.h2
-rw-r--r--src/sksl/ir/SkSLFieldAccess.h2
-rw-r--r--src/sksl/ir/SkSLForStatement.h4
-rw-r--r--src/sksl/ir/SkSLFunctionCall.h2
-rw-r--r--src/sksl/ir/SkSLIfStatement.h2
-rw-r--r--src/sksl/ir/SkSLIndexExpression.h4
-rw-r--r--src/sksl/ir/SkSLPostfixExpression.h2
-rw-r--r--src/sksl/ir/SkSLPrefixExpression.h2
-rw-r--r--src/sksl/ir/SkSLProgram.h6
-rw-r--r--src/sksl/ir/SkSLReturnStatement.h2
-rw-r--r--src/sksl/ir/SkSLSwizzle.h2
-rw-r--r--src/sksl/ir/SkSLTernaryExpression.h6
-rw-r--r--src/sksl/ir/SkSLVarDeclarations.h2
-rw-r--r--src/sksl/ir/SkSLVarDeclarationsStatement.h4
-rw-r--r--src/sksl/ir/SkSLVariable.h12
-rw-r--r--src/sksl/ir/SkSLVariableReference.h71
-rw-r--r--src/sksl/ir/SkSLWhileStatement.h2
30 files changed, 412 insertions, 211 deletions
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 964a8dc84a..31bace9fb7 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -54,8 +54,8 @@ void CFG::dump() {
printf("Block %d\n-------\nBefore: ", (int) i);
const char* separator = "";
for (auto iter = fBlocks[i].fBefore.begin(); iter != fBlocks[i].fBefore.end(); iter++) {
- printf("%s%s = %s", separator, iter->first->description().c_str(),
- iter->second ? iter->second->description().c_str() : "<undefined>");
+ printf("%s%s = %s", separator, iter->first->description().c_str(),
+ *iter->second ? (*iter->second)->description().c_str() : "<undefined>");
separator = ", ";
}
printf("\nEntrances: ");
@@ -66,7 +66,10 @@ void CFG::dump() {
}
printf("\n");
for (size_t j = 0; j < fBlocks[i].fNodes.size(); j++) {
- printf("Node %d: %s\n", (int) j, fBlocks[i].fNodes[j].fNode->description().c_str());
+ BasicBlock::Node& n = fBlocks[i].fNodes[j];
+ printf("Node %d: %s\n", (int) j, n.fKind == BasicBlock::Node::kExpression_Kind
+ ? (*n.fExpression)->description().c_str()
+ : n.fStatement->description().c_str());
}
printf("Exits: ");
separator = "";
@@ -78,96 +81,109 @@ void CFG::dump() {
}
}
-void CFGGenerator::addExpression(CFG& cfg, const Expression* e) {
- switch (e->fKind) {
+void CFGGenerator::addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool constantPropagate) {
+ ASSERT(e);
+ switch ((*e)->fKind) {
case Expression::kBinary_Kind: {
- const BinaryExpression* b = (const BinaryExpression*) e;
+ BinaryExpression* b = (BinaryExpression*) e->get();
switch (b->fOperator) {
case Token::LOGICALAND: // fall through
case Token::LOGICALOR: {
// this isn't as precise as it could be -- we don't bother to track that if we
// early exit from a logical and/or, we know which branch of an 'if' we're going
// to hit -- but it won't make much difference in practice.
- this->addExpression(cfg, b->fLeft.get());
+ this->addExpression(cfg, &b->fLeft, constantPropagate);
BlockId start = cfg.fCurrent;
cfg.newBlock();
- this->addExpression(cfg, b->fRight.get());
+ this->addExpression(cfg, &b->fRight, constantPropagate);
cfg.newBlock();
cfg.addExit(start, cfg.fCurrent);
break;
}
case Token::EQ: {
- this->addExpression(cfg, b->fRight.get());
- this->addLValue(cfg, b->fLeft.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({
- BasicBlock::Node::kExpression_Kind,
- b
+ this->addExpression(cfg, &b->fRight, constantPropagate);
+ this->addLValue(cfg, &b->fLeft);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({
+ BasicBlock::Node::kExpression_Kind,
+ constantPropagate,
+ e,
+ nullptr
});
break;
}
default:
- this->addExpression(cfg, b->fLeft.get());
- this->addExpression(cfg, b->fRight.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({
- BasicBlock::Node::kExpression_Kind,
- b
+ this->addExpression(cfg, &b->fLeft, !Token::IsAssignment(b->fOperator));
+ this->addExpression(cfg, &b->fRight, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({
+ BasicBlock::Node::kExpression_Kind,
+ constantPropagate,
+ e,
+ nullptr
});
}
break;
}
case Expression::kConstructor_Kind: {
- const Constructor* c = (const Constructor*) e;
- for (const auto& arg : c->fArguments) {
- this->addExpression(cfg, arg.get());
+ Constructor* c = (Constructor*) e->get();
+ for (auto& arg : c->fArguments) {
+ this->addExpression(cfg, &arg, constantPropagate);
}
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
}
case Expression::kFunctionCall_Kind: {
- const FunctionCall* c = (const FunctionCall*) e;
- for (const auto& arg : c->fArguments) {
- this->addExpression(cfg, arg.get());
+ FunctionCall* c = (FunctionCall*) e->get();
+ for (auto& arg : c->fArguments) {
+ this->addExpression(cfg, &arg, constantPropagate);
}
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, c });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
}
case Expression::kFieldAccess_Kind:
- this->addExpression(cfg, ((const FieldAccess*) e)->fBase.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ this->addExpression(cfg, &((FieldAccess*) e->get())->fBase, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kIndex_Kind:
- this->addExpression(cfg, ((const IndexExpression*) e)->fBase.get());
- this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ this->addExpression(cfg, &((IndexExpression*) e->get())->fBase, constantPropagate);
+ this->addExpression(cfg, &((IndexExpression*) e->get())->fIndex, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kPrefix_Kind:
- this->addExpression(cfg, ((const PrefixExpression*) e)->fOperand.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ this->addExpression(cfg, &((PrefixExpression*) e->get())->fOperand, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kPostfix_Kind:
- this->addExpression(cfg, ((const PostfixExpression*) e)->fOperand.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ this->addExpression(cfg, &((PostfixExpression*) e->get())->fOperand, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kSwizzle_Kind:
- this->addExpression(cfg, ((const Swizzle*) e)->fBase.get());
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ this->addExpression(cfg, &((Swizzle*) e->get())->fBase, constantPropagate);
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kBoolLiteral_Kind: // fall through
case Expression::kFloatLiteral_Kind: // fall through
case Expression::kIntLiteral_Kind: // fall through
case Expression::kVariableReference_Kind:
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind, e });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kExpression_Kind,
+ constantPropagate, e, nullptr });
break;
case Expression::kTernary_Kind: {
- const TernaryExpression* t = (const TernaryExpression*) e;
- this->addExpression(cfg, t->fTest.get());
+ TernaryExpression* t = (TernaryExpression*) e->get();
+ this->addExpression(cfg, &t->fTest, constantPropagate);
BlockId start = cfg.fCurrent;
cfg.newBlock();
- this->addExpression(cfg, t->fIfTrue.get());
+ this->addExpression(cfg, &t->fIfTrue, constantPropagate);
BlockId next = cfg.newBlock();
cfg.fCurrent = start;
cfg.newBlock();
- this->addExpression(cfg, t->fIfFalse.get());
+ this->addExpression(cfg, &t->fIfFalse, constantPropagate);
cfg.addExit(cfg.fCurrent, next);
cfg.fCurrent = next;
break;
@@ -181,17 +197,17 @@ void CFGGenerator::addExpression(CFG& cfg, const Expression* e) {
}
// adds expressions that are evaluated as part of resolving an lvalue
-void CFGGenerator::addLValue(CFG& cfg, const Expression* e) {
- switch (e->fKind) {
+void CFGGenerator::addLValue(CFG& cfg, std::unique_ptr<Expression>* e) {
+ switch ((*e)->fKind) {
case Expression::kFieldAccess_Kind:
- this->addLValue(cfg, ((const FieldAccess*) e)->fBase.get());
+ this->addLValue(cfg, &((FieldAccess&) **e).fBase);
break;
case Expression::kIndex_Kind:
- this->addLValue(cfg, ((const IndexExpression*) e)->fBase.get());
- this->addExpression(cfg, ((const IndexExpression*) e)->fIndex.get());
+ this->addLValue(cfg, &((IndexExpression&) **e).fBase);
+ this->addExpression(cfg, &((IndexExpression&) **e).fIndex, true);
break;
case Expression::kSwizzle_Kind:
- this->addLValue(cfg, ((const Swizzle*) e)->fBase.get());
+ this->addLValue(cfg, &((Swizzle&) **e).fBase);
break;
case Expression::kVariableReference_Kind:
break;
@@ -210,8 +226,8 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
}
break;
case Statement::kIf_Kind: {
- const IfStatement* ifs = (const IfStatement*) s;
- this->addExpression(cfg, ifs->fTest.get());
+ IfStatement* ifs = (IfStatement*) s;
+ this->addExpression(cfg, &ifs->fTest, true);
BlockId start = cfg.fCurrent;
cfg.newBlock();
this->addStatement(cfg, ifs->fIfTrue.get());
@@ -228,49 +244,54 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
break;
}
case Statement::kExpression_Kind: {
- this->addExpression(cfg, ((ExpressionStatement&) *s).fExpression.get());
+ this->addExpression(cfg, &((ExpressionStatement&) *s).fExpression, true);
break;
}
case Statement::kVarDeclarations_Kind: {
- const VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s);
- for (const auto& vd : decls.fDeclaration->fVars) {
+ VarDeclarationsStatement& decls = ((VarDeclarationsStatement&) *s);
+ for (auto& vd : decls.fDeclaration->fVars) {
if (vd.fValue) {
- this->addExpression(cfg, vd.fValue.get());
+ this->addExpression(cfg, &vd.fValue, true);
}
}
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false,
+ nullptr, s });
break;
}
case Statement::kDiscard_Kind:
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false,
+ nullptr, s });
cfg.fCurrent = cfg.newIsolatedBlock();
break;
case Statement::kReturn_Kind: {
- const ReturnStatement& r = ((ReturnStatement&) *s);
+ ReturnStatement& r = ((ReturnStatement&) *s);
if (r.fExpression) {
- this->addExpression(cfg, r.fExpression.get());
+ this->addExpression(cfg, &r.fExpression, true);
}
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false,
+ nullptr, s });
cfg.fCurrent = cfg.newIsolatedBlock();
break;
}
case Statement::kBreak_Kind:
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false,
+ nullptr, s });
cfg.addExit(cfg.fCurrent, fLoopExits.top());
cfg.fCurrent = cfg.newIsolatedBlock();
break;
case Statement::kContinue_Kind:
- cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, s });
+ cfg.fBlocks[cfg.fCurrent].fNodes.push_back({ BasicBlock::Node::kStatement_Kind, false,
+ nullptr, s });
cfg.addExit(cfg.fCurrent, fLoopContinues.top());
cfg.fCurrent = cfg.newIsolatedBlock();
break;
case Statement::kWhile_Kind: {
- const WhileStatement* w = (const WhileStatement*) s;
+ WhileStatement* w = (WhileStatement*) s;
BlockId loopStart = cfg.newBlock();
fLoopContinues.push(loopStart);
BlockId loopExit = cfg.newIsolatedBlock();
fLoopExits.push(loopExit);
- this->addExpression(cfg, w->fTest.get());
+ this->addExpression(cfg, &w->fTest, true);
BlockId test = cfg.fCurrent;
cfg.addExit(test, loopExit);
cfg.newBlock();
@@ -282,13 +303,13 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
break;
}
case Statement::kDo_Kind: {
- const DoStatement* d = (const DoStatement*) s;
+ DoStatement* d = (DoStatement*) s;
BlockId loopStart = cfg.newBlock();
fLoopContinues.push(loopStart);
BlockId loopExit = cfg.newIsolatedBlock();
fLoopExits.push(loopExit);
this->addStatement(cfg, d->fStatement.get());
- this->addExpression(cfg, d->fTest.get());
+ this->addExpression(cfg, &d->fTest, true);
cfg.addExit(cfg.fCurrent, loopExit);
cfg.addExit(cfg.fCurrent, loopStart);
fLoopContinues.pop();
@@ -297,7 +318,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
break;
}
case Statement::kFor_Kind: {
- const ForStatement* f = (const ForStatement*) s;
+ ForStatement* f = (ForStatement*) s;
if (f->fInitializer) {
this->addStatement(cfg, f->fInitializer.get());
}
@@ -307,7 +328,7 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
BlockId loopExit = cfg.newIsolatedBlock();
fLoopExits.push(loopExit);
if (f->fTest) {
- this->addExpression(cfg, f->fTest.get());
+ this->addExpression(cfg, &f->fTest, true);
BlockId test = cfg.fCurrent;
cfg.addExit(test, loopExit);
}
@@ -316,9 +337,9 @@ void CFGGenerator::addStatement(CFG& cfg, const Statement* s) {
cfg.addExit(cfg.fCurrent, next);
cfg.fCurrent = next;
if (f->fNext) {
- this->addExpression(cfg, f->fNext.get());
+ this->addExpression(cfg, &f->fNext, true);
}
- cfg.addExit(next, loopStart);
+ cfg.addExit(cfg.fCurrent, loopStart);
fLoopContinues.pop();
fLoopExits.pop();
cfg.fCurrent = loopExit;
diff --git a/src/sksl/SkSLCFGGenerator.h b/src/sksl/SkSLCFGGenerator.h
index c37850112c..337fdfac35 100644
--- a/src/sksl/SkSLCFGGenerator.h
+++ b/src/sksl/SkSLCFGGenerator.h
@@ -27,14 +27,23 @@ struct BasicBlock {
};
Kind fKind;
- const IRNode* fNode;
+ // if false, this node should not be subject to constant propagation. This happens with
+ // compound assignment (i.e. x *= 2), in which the value x is used as an rvalue for
+ // multiplication by 2 and then as an lvalue for assignment purposes. Since there is only
+ // one "x" node, replacing it with a constant would break the assignment and we suppress
+ // it. Down the road, we should handle this more elegantly by substituting a regular
+ // assignment if the target is constant (i.e. x = 1; x *= 2; should become x = 1; x = 1 * 2;
+ // and then collapse down to a simple x = 2;).
+ bool fConstantPropagation;
+ std::unique_ptr<Expression>* fExpression;
+ const Statement* fStatement;
};
-
+
std::vector<Node> fNodes;
std::set<BlockId> fEntrances;
std::set<BlockId> fExits;
// variable definitions upon entering this basic block (null expression = undefined)
- std::unordered_map<const Variable*, const Expression*> fBefore;
+ DefinitionMap fBefore;
};
struct CFG {
@@ -77,9 +86,9 @@ public:
private:
void addStatement(CFG& cfg, const Statement* s);
- void addExpression(CFG& cfg, const Expression* e);
+ void addExpression(CFG& cfg, std::unique_ptr<Expression>* e, bool constantPropagate);
- void addLValue(CFG& cfg, const Expression* e);
+ void addLValue(CFG& cfg, std::unique_ptr<Expression>* e);
std::stack<BlockId> fLoopContinues;
std::stack<BlockId> fLoopExits;
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 9faf836156..743745ad14 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -156,8 +156,8 @@ Compiler::~Compiler() {
}
// add the definition created by assigning to the lvalue to the definition set
-void Compiler::addDefinition(const Expression* lvalue, const Expression* expr,
- std::unordered_map<const Variable*, const Expression*>* definitions) {
+void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
+ DefinitionMap* definitions) {
switch (lvalue->fKind) {
case Expression::kVariableReference_Kind: {
const Variable& var = ((VariableReference*) lvalue)->fVariable;
@@ -174,19 +174,19 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr,
// but since we pass foo as a whole it is flagged as an error) unless we perform a much
// more complicated whole-program analysis. This is probably good enough.
this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
- fContext.fDefined_Expression.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
break;
case Expression::kIndex_Kind:
// see comments in Swizzle
this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
- fContext.fDefined_Expression.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
break;
case Expression::kFieldAccess_Kind:
// see comments in Swizzle
this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
- fContext.fDefined_Expression.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
definitions);
break;
default:
@@ -197,25 +197,58 @@ void Compiler::addDefinition(const Expression* lvalue, const Expression* expr,
// add local variables defined by this node to the set
void Compiler::addDefinitions(const BasicBlock::Node& node,
- std::unordered_map<const Variable*, const Expression*>* definitions) {
+ DefinitionMap* definitions) {
switch (node.fKind) {
case BasicBlock::Node::kExpression_Kind: {
- const Expression* expr = (Expression*) node.fNode;
- if (expr->fKind == Expression::kBinary_Kind) {
- const BinaryExpression* b = (BinaryExpression*) expr;
- if (b->fOperator == Token::EQ) {
- this->addDefinition(b->fLeft.get(), b->fRight.get(), definitions);
+ ASSERT(node.fExpression);
+ const Expression* expr = (Expression*) node.fExpression->get();
+ switch (expr->fKind) {
+ case Expression::kBinary_Kind: {
+ BinaryExpression* b = (BinaryExpression*) expr;
+ if (b->fOperator == Token::EQ) {
+ this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
+ } else if (Token::IsAssignment(b->fOperator)) {
+ this->addDefinition(
+ b->fLeft.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ definitions);
+
+ }
+ break;
+ }
+ case Expression::kPrefix_Kind: {
+ const PrefixExpression* p = (PrefixExpression*) expr;
+ if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
+ this->addDefinition(
+ p->fOperand.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ definitions);
+ }
+ break;
}
+ case Expression::kPostfix_Kind: {
+ const PostfixExpression* p = (PostfixExpression*) expr;
+ if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
+ this->addDefinition(
+ p->fOperand.get(),
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
+ definitions);
+
+ }
+ break;
+ }
+ default:
+ break;
}
break;
}
case BasicBlock::Node::kStatement_Kind: {
- const Statement* stmt = (Statement*) node.fNode;
+ const Statement* stmt = (Statement*) node.fStatement;
if (stmt->fKind == Statement::kVarDeclarations_Kind) {
- const VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt;
- for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
+ VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt;
+ for (VarDeclaration& decl : vd->fDeclaration->fVars) {
if (decl.fValue) {
- (*definitions)[decl.fVar] = decl.fValue.get();
+ (*definitions)[decl.fVar] = &decl.fValue;
}
}
}
@@ -228,7 +261,7 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
BasicBlock& block = cfg->fBlocks[blockId];
// compute definitions after this block
- std::unordered_map<const Variable*, const Expression*> after = block.fBefore;
+ DefinitionMap after = block.fBefore;
for (const BasicBlock::Node& n : block.fNodes) {
this->addDefinitions(n, &after);
}
@@ -237,19 +270,20 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
for (BlockId exitId : block.fExits) {
BasicBlock& exit = cfg->fBlocks[exitId];
for (const auto& pair : after) {
- const Expression* e1 = pair.second;
- if (exit.fBefore.find(pair.first) == exit.fBefore.end()) {
+ std::unique_ptr<Expression>* e1 = pair.second;
+ auto found = exit.fBefore.find(pair.first);
+ if (found == exit.fBefore.end()) {
+ // exit has no definition for it, just copy it
+ workList->insert(exitId);
exit.fBefore[pair.first] = e1;
} else {
- const Expression* e2 = exit.fBefore[pair.first];
+ // exit has a (possibly different) value already defined
+ std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
if (e1 != e2) {
// definition has changed, merge and add exit block to worklist
workList->insert(exitId);
- if (!e1 || !e2) {
- exit.fBefore[pair.first] = nullptr;
- } else {
- exit.fBefore[pair.first] = fContext.fDefined_Expression.get();
- }
+ exit.fBefore[pair.first] =
+ (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
}
}
}
@@ -258,12 +292,13 @@ void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
// returns a map which maps all local variables in the function to null, indicating that their value
// is initially unknown
-static std::unordered_map<const Variable*, const Expression*> compute_start_state(const CFG& cfg) {
- std::unordered_map<const Variable*, const Expression*> result;
+static DefinitionMap compute_start_state(const CFG& cfg) {
+ DefinitionMap result;
for (const auto& block : cfg.fBlocks) {
for (const auto& node : block.fNodes) {
if (node.fKind == BasicBlock::Node::kStatement_Kind) {
- const Statement* s = (Statement*) node.fNode;
+ ASSERT(node.fStatement);
+ const Statement* s = node.fStatement;
if (s->fKind == Statement::kVarDeclarations_Kind) {
const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
@@ -295,19 +330,37 @@ void Compiler::scanCFG(const FunctionDefinition& f) {
for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
cfg.fBlocks[i].fNodes.size()) {
- this->error(cfg.fBlocks[i].fNodes[0].fNode->fPosition, SkString("unreachable"));
+ Position p;
+ switch (cfg.fBlocks[i].fNodes[0].fKind) {
+ case BasicBlock::Node::kStatement_Kind:
+ p = cfg.fBlocks[i].fNodes[0].fStatement->fPosition;
+ break;
+ case BasicBlock::Node::kExpression_Kind:
+ p = (*cfg.fBlocks[i].fNodes[0].fExpression)->fPosition;
+ break;
+ }
+ this->error(p, SkString("unreachable"));
}
}
if (fErrorCount) {
return;
}
- // check for undefined variables
- for (const BasicBlock& b : cfg.fBlocks) {
- std::unordered_map<const Variable*, const Expression*> definitions = b.fBefore;
- for (const BasicBlock::Node& n : b.fNodes) {
+ // check for undefined variables, perform constant propagation
+ for (BasicBlock& b : cfg.fBlocks) {
+ DefinitionMap definitions = b.fBefore;
+ for (BasicBlock::Node& n : b.fNodes) {
if (n.fKind == BasicBlock::Node::kExpression_Kind) {
- const Expression* expr = (const Expression*) n.fNode;
+ ASSERT(n.fExpression);
+ Expression* expr = n.fExpression->get();
+ if (n.fConstantPropagation) {
+ std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator,
+ definitions);
+ if (optimized) {
+ n.fExpression->reset(optimized.release());
+ expr = n.fExpression->get();
+ }
+ }
if (expr->fKind == Expression::kVariableReference_Kind) {
const Variable& var = ((VariableReference*) expr)->fVariable;
if (var.fStorage == Variable::kLocal_Storage &&
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 0f893f7e64..fdca12d2cf 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -60,11 +60,10 @@ public:
}
private:
- void addDefinition(const Expression* lvalue, const Expression* expr,
- std::unordered_map<const Variable*, const Expression*>* definitions);
+ void addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
+ DefinitionMap* definitions);
- void addDefinitions(const BasicBlock::Node& node,
- std::unordered_map<const Variable*, const Expression*>* definitions);
+ void addDefinitions(const BasicBlock::Node& node, DefinitionMap* definitions);
void scanCFG(CFG* cfg, BlockId block, std::set<BlockId>* workList);
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 9f06c97d11..55d9d2c8d6 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -551,11 +551,11 @@ std::unique_ptr<InterfaceBlock> IRGenerator::convertInterfaceBlock(const ASTInte
}
}
Type* type = new Type(intf.fPosition, intf.fInterfaceName, fields);
- fSymbolTable->takeOwnership(type);
+ old->takeOwnership(type);
SkString name = intf.fValueName.size() > 0 ? intf.fValueName : intf.fInterfaceName;
Variable* var = new Variable(intf.fPosition, intf.fModifiers, name, *type,
Variable::kGlobal_Storage);
- fSymbolTable->takeOwnership(var);
+ old->takeOwnership(var);
if (intf.fValueName.size()) {
old->addWithoutOwnership(intf.fValueName, var);
} else {
@@ -624,19 +624,22 @@ std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTIdentifier&
f->fFunctions));
}
case Symbol::kVariable_Kind: {
- const Variable* var = (const Variable*) result;
- this->markReadFrom(*var);
+ Variable* var = (Variable*) result;
if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
fSettings->fFlipY &&
(!fSettings->fCaps || !fSettings->fCaps->fragCoordConventionsExtensionString())) {
fInputs.fRTHeight = true;
}
- return std::unique_ptr<VariableReference>(new VariableReference(identifier.fPosition,
- *var));
+ // default to kRead_RefKind; this will be corrected later if the variable is written to
+ return std::unique_ptr<VariableReference>(new VariableReference(
+ identifier.fPosition,
+ *var,
+ VariableReference::kRead_RefKind));
}
case Symbol::kField_Kind: {
const Field* field = (const Field*) result;
- VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner);
+ VariableReference* base = new VariableReference(identifier.fPosition, field->fOwner,
+ VariableReference::kRead_RefKind);
return std::unique_ptr<Expression>(new FieldAccess(
std::unique_ptr<Expression>(base),
field->fFieldIndex,
@@ -690,28 +693,6 @@ static bool is_matrix_multiply(const Type& left, const Type& right) {
return left.kind() == Type::kVector_Kind && right.kind() == Type::kMatrix_Kind;
}
-static bool is_assignment(Token::Kind op) {
- switch (op) {
- case Token::EQ: // fall through
- case Token::PLUSEQ: // fall through
- case Token::MINUSEQ: // fall through
- case Token::STAREQ: // fall through
- case Token::SLASHEQ: // fall through
- case Token::PERCENTEQ: // fall through
- case Token::SHLEQ: // fall through
- case Token::SHREQ: // fall through
- case Token::BITWISEOREQ: // fall through
- case Token::BITWISEXOREQ: // fall through
- case Token::BITWISEANDEQ: // fall through
- case Token::LOGICALOREQ: // fall through
- case Token::LOGICALXOREQ: // fall through
- case Token::LOGICALANDEQ:
- return true;
- default:
- return false;
- }
-}
-
/**
* Determines the operand and result types of a binary expression. Returns true if the expression is
* legal, false otherwise. If false, the values of the out parameters are undefined.
@@ -842,14 +823,9 @@ static bool determine_binary_type(const Context& context,
return false;
}
-/**
- * If both operands are compile-time constants and can be folded, returns an expression representing
- * the folded value. Otherwise, returns null. Note that unlike most other functions here, null does
- * not represent a compilation error.
- */
std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left,
Token::Kind op,
- const Expression& right) {
+ const Expression& right) const {
// Note that we expressly do not worry about precision and overflow here -- we use the maximum
// precision to calculate the results and hope the result makes sense. The plan is to move the
// Skia caps into SkSL, so we have access to all of them including the precisions of the various
@@ -943,15 +919,16 @@ std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
const Type* rightType;
const Type* resultType;
if (!determine_binary_type(fContext, expression.fOperator, left->fType, right->fType, &leftType,
- &rightType, &resultType, !is_assignment(expression.fOperator))) {
+ &rightType, &resultType,
+ !Token::IsAssignment(expression.fOperator))) {
fErrors.error(expression.fPosition, "type mismatch: '" +
Token::OperatorName(expression.fOperator) +
"' cannot operate on '" + left->fType.fName +
"', '" + right->fType.fName + "'");
return nullptr;
}
- if (is_assignment(expression.fOperator)) {
- this->markWrittenTo(*left);
+ if (Token::IsAssignment(expression.fOperator)) {
+ this->markWrittenTo(*left, expression.fOperator != Token::EQ);
}
left = this->coerce(std::move(left), *leftType);
right = this->coerce(std::move(right), *rightType);
@@ -1051,7 +1028,7 @@ std::unique_ptr<Expression> IRGenerator::call(Position position,
return nullptr;
}
if (arguments[i] && (function.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag)) {
- this->markWrittenTo(*arguments[i]);
+ this->markWrittenTo(*arguments[i], true);
}
}
return std::unique_ptr<FunctionCall>(new FunctionCall(position, *returnType, function,
@@ -1261,7 +1238,7 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
"' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
- this->markWrittenTo(*base);
+ this->markWrittenTo(*base, true);
break;
case Token::MINUSMINUS:
if (!base->fType.isNumber()) {
@@ -1270,7 +1247,7 @@ std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(
"' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
- this->markWrittenTo(*base);
+ this->markWrittenTo(*base, true);
break;
case Token::LOGICALNOT:
if (base->fType != *fContext.fBool_Type) {
@@ -1464,7 +1441,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
"'++' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
- this->markWrittenTo(*base);
+ this->markWrittenTo(*base, true);
return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
Token::PLUSPLUS));
case ASTSuffix::kPostDecrement_Kind:
@@ -1473,7 +1450,7 @@ std::unique_ptr<Expression> IRGenerator::convertSuffixExpression(
"'--' cannot operate on '" + base->fType.description() + "'");
return nullptr;
}
- this->markWrittenTo(*base);
+ this->markWrittenTo(*base, true);
return std::unique_ptr<Expression>(new PostfixExpression(std::move(base),
Token::MINUSMINUS));
default:
@@ -1496,10 +1473,6 @@ void IRGenerator::checkValid(const Expression& expr) {
}
}
-void IRGenerator::markReadFrom(const Variable& var) {
- var.fIsReadFrom = true;
-}
-
static bool has_duplicates(const Swizzle& swizzle) {
int bits = 0;
for (int idx : swizzle.fComponents) {
@@ -1513,7 +1486,7 @@ static bool has_duplicates(const Swizzle& swizzle) {
return false;
}
-void IRGenerator::markWrittenTo(const Expression& expr) {
+void IRGenerator::markWrittenTo(const Expression& expr, bool readWrite) {
switch (expr.fKind) {
case Expression::kVariableReference_Kind: {
const Variable& var = ((VariableReference&) expr).fVariable;
@@ -1521,21 +1494,22 @@ void IRGenerator::markWrittenTo(const Expression& expr) {
fErrors.error(expr.fPosition,
"cannot modify immutable variable '" + var.fName + "'");
}
- var.fIsWrittenTo = true;
+ ((VariableReference&) expr).setRefKind(readWrite ? VariableReference::kReadWrite_RefKind
+ : VariableReference::kWrite_RefKind);
break;
}
case Expression::kFieldAccess_Kind:
- this->markWrittenTo(*((FieldAccess&) expr).fBase);
+ this->markWrittenTo(*((FieldAccess&) expr).fBase, readWrite);
break;
case Expression::kSwizzle_Kind:
if (has_duplicates((Swizzle&) expr)) {
fErrors.error(expr.fPosition,
"cannot write to the same swizzle field more than once");
}
- this->markWrittenTo(*((Swizzle&) expr).fBase);
+ this->markWrittenTo(*((Swizzle&) expr).fBase, readWrite);
break;
case Expression::kIndex_Kind:
- this->markWrittenTo(*((IndexExpression&) expr).fBase);
+ this->markWrittenTo(*((IndexExpression&) expr).fBase, readWrite);
break;
default:
fErrors.error(expr.fPosition, "cannot assign to '" + expr.description() + "'");
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 13b20fbbcc..2ffcb0df26 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -88,7 +88,16 @@ public:
std::unique_ptr<ModifiersDeclaration> convertModifiersDeclaration(
const ASTModifiersDeclaration& m);
+ /**
+ * If both operands are compile-time constants and can be folded, returns an expression
+ * representing the folded value. Otherwise, returns null. Note that unlike most other functions
+ * here, null does not represent a compilation error.
+ */
+ std::unique_ptr<Expression> constantFold(const Expression& left,
+ Token::Kind op,
+ const Expression& right) const;
Program::Inputs fInputs;
+ const Context& fContext;
private:
/**
@@ -124,11 +133,6 @@ private:
std::unique_ptr<Statement> convertDiscard(const ASTDiscardStatement& d);
std::unique_ptr<Statement> convertDo(const ASTDoStatement& d);
std::unique_ptr<Expression> convertBinaryExpression(const ASTBinaryExpression& expression);
- // Returns null if it cannot fold the expression. Note that unlike most other functions here, a
- // null return does not represent a compilation error.
- std::unique_ptr<Expression> constantFold(const Expression& left,
- Token::Kind op,
- const Expression& right);
std::unique_ptr<Extension> convertExtension(const ASTExtension& e);
std::unique_ptr<Statement> convertExpressionStatement(const ASTExpressionStatement& s);
std::unique_ptr<Statement> convertFor(const ASTForStatement& f);
@@ -151,10 +155,8 @@ private:
std::unique_ptr<Statement> convertWhile(const ASTWhileStatement& w);
void checkValid(const Expression& expr);
- void markReadFrom(const Variable& var);
- void markWrittenTo(const Expression& expr);
+ void markWrittenTo(const Expression& expr, bool readWrite);
- const Context& fContext;
const FunctionDeclaration* fCurrentFunction;
const Program::Settings* fSettings;
std::unordered_map<SkString, CapValue> fCapsMap;
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 8afd13688c..d43e4c4035 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -2540,7 +2540,7 @@ void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclaratio
kind != Program::kFragment_Kind) {
continue;
}
- if (!var->fIsReadFrom && !var->fIsWrittenTo &&
+ if (!var->fReadCount && !var->fWriteCount &&
!(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
Modifiers::kOut_Flag |
Modifiers::kUniform_Flag))) {
diff --git a/src/sksl/SkSLToken.h b/src/sksl/SkSLToken.h
index 5c8c2bd215..197781f2a0 100644
--- a/src/sksl/SkSLToken.h
+++ b/src/sksl/SkSLToken.h
@@ -160,6 +160,28 @@ struct Token {
, fKind(kind)
, fText(std::move(text)) {}
+ static bool IsAssignment(Token::Kind op) {
+ switch (op) {
+ case Token::EQ: // fall through
+ case Token::PLUSEQ: // fall through
+ case Token::MINUSEQ: // fall through
+ case Token::STAREQ: // fall through
+ case Token::SLASHEQ: // fall through
+ case Token::PERCENTEQ: // fall through
+ case Token::SHLEQ: // fall through
+ case Token::SHREQ: // fall through
+ case Token::BITWISEOREQ: // fall through
+ case Token::BITWISEXOREQ: // fall through
+ case Token::BITWISEANDEQ: // fall through
+ case Token::LOGICALOREQ: // fall through
+ case Token::LOGICALXOREQ: // fall through
+ case Token::LOGICALANDEQ:
+ return true;
+ default:
+ return false;
+ }
+ }
+
Position fPosition;
Kind fKind;
// will be the empty string unless the token has variable text content (identifiers, numeric
diff --git a/src/sksl/ir/SkSLBinaryExpression.h b/src/sksl/ir/SkSLBinaryExpression.h
index 132513e7f7..de85e4812b 100644
--- a/src/sksl/ir/SkSLBinaryExpression.h
+++ b/src/sksl/ir/SkSLBinaryExpression.h
@@ -4,17 +4,19 @@
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
-
+
#ifndef SKSL_BINARYEXPRESSION
#define SKSL_BINARYEXPRESSION
#include "SkSLExpression.h"
+#include "SkSLExpression.h"
+#include "../SkSLIRGenerator.h"
#include "../SkSLToken.h"
namespace SkSL {
/**
- * A binary operation.
+ * A binary operation.
*/
struct BinaryExpression : public Expression {
BinaryExpression(Position position, std::unique_ptr<Expression> left, Token::Kind op,
@@ -24,14 +26,22 @@ struct BinaryExpression : public Expression {
, fOperator(op)
, fRight(std::move(right)) {}
+ virtual std::unique_ptr<Expression> constantPropagate(
+ const IRGenerator& irGenerator,
+ const DefinitionMap& definitions) override {
+ return irGenerator.constantFold(*fLeft,
+ fOperator,
+ *fRight);
+ }
+
virtual SkString description() const override {
return "(" + fLeft->description() + " " + Token::OperatorName(fOperator) + " " +
fRight->description() + ")";
}
- const std::unique_ptr<Expression> fLeft;
+ std::unique_ptr<Expression> fLeft;
const Token::Kind fOperator;
- const std::unique_ptr<Expression> fRight;
+ std::unique_ptr<Expression> fRight;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLBlock.h b/src/sksl/ir/SkSLBlock.h
index f975d160a0..17970fd561 100644
--- a/src/sksl/ir/SkSLBlock.h
+++ b/src/sksl/ir/SkSLBlock.h
@@ -20,8 +20,8 @@ struct Block : public Statement {
Block(Position position, std::vector<std::unique_ptr<Statement>> statements,
const std::shared_ptr<SymbolTable> symbols)
: INHERITED(position, kBlock_Kind)
- , fStatements(std::move(statements))
- , fSymbols(std::move(symbols)) {}
+ , fSymbols(std::move(symbols))
+ , fStatements(std::move(statements)) {}
SkString description() const override {
SkString result("{");
@@ -33,8 +33,10 @@ struct Block : public Statement {
return result;
}
- const std::vector<std::unique_ptr<Statement>> fStatements;
+ // it's important to keep fStatements defined after (and thus destroyed before) fSymbols,
+ // because destroying statements can modify reference counts in symbols
const std::shared_ptr<SymbolTable> fSymbols;
+ const std::vector<std::unique_ptr<Statement>> fStatements;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 63c692b88e..691bea123a 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -9,6 +9,9 @@
#define SKSL_CONSTRUCTOR
#include "SkSLExpression.h"
+#include "SkSLFloatLiteral.h"
+#include "SkSLIntLiteral.h"
+#include "SkSLIRGenerator.h"
namespace SkSL {
@@ -21,6 +24,20 @@ struct Constructor : public Expression {
: INHERITED(position, kConstructor_Kind, type)
, fArguments(std::move(arguments)) {}
+ virtual std::unique_ptr<Expression> constantPropagate(
+ const IRGenerator& irGenerator,
+ const DefinitionMap& definitions) override {
+ if (fArguments.size() == 1 && fArguments[0]->fKind == Expression::kIntLiteral_Kind &&
+ // promote float(1) to 1.0
+ fType == *irGenerator.fContext.fFloat_Type) {
+ int64_t intValue = ((IntLiteral&) *fArguments[0]).fValue;
+ return std::unique_ptr<Expression>(new FloatLiteral(irGenerator.fContext,
+ fPosition,
+ intValue));
+ }
+ return nullptr;
+ }
+
SkString description() const override {
SkString result = fType.description() + "(";
SkString separator;
@@ -42,7 +59,7 @@ struct Constructor : public Expression {
return true;
}
- const std::vector<std::unique_ptr<Expression>> fArguments;
+ std::vector<std::unique_ptr<Expression>> fArguments;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLDoStatement.h b/src/sksl/ir/SkSLDoStatement.h
index 78c0a1b768..e26d3dc974 100644
--- a/src/sksl/ir/SkSLDoStatement.h
+++ b/src/sksl/ir/SkSLDoStatement.h
@@ -28,7 +28,7 @@ struct DoStatement : public Statement {
}
const std::unique_ptr<Statement> fStatement;
- const std::unique_ptr<Expression> fTest;
+ std::unique_ptr<Expression> fTest;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index b4ed37c09a..f87d810fc0 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -4,17 +4,24 @@
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/
-
+
#ifndef SKSL_EXPRESSION
#define SKSL_EXPRESSION
-#include "SkSLIRNode.h"
#include "SkSLType.h"
+#include "SkSLVariable.h"
+
+#include <unordered_map>
namespace SkSL {
+struct Expression;
+class IRGenerator;
+
+typedef std::unordered_map<const Variable*, std::unique_ptr<Expression>*> DefinitionMap;
+
/**
- * Abstract supertype of all expressions.
+ * Abstract supertype of all expressions.
*/
struct Expression : public IRNode {
enum Kind {
@@ -45,6 +52,18 @@ struct Expression : public IRNode {
return false;
}
+ /**
+ * Given a map of known constant variable values, substitute them in for references to those
+ * variables occurring in this expression and its subexpressions. Similar simplifications, such
+ * as folding a constant binary expression down to a single value, may also be performed.
+ * Returns a new expression which replaces this expression, or null if no replacements were
+ * made. If a new expression is returned, this expression is no longer valid.
+ */
+ virtual std::unique_ptr<Expression> constantPropagate(const IRGenerator& irGenerator,
+ const DefinitionMap& definitions) {
+ return nullptr;
+ }
+
const Kind fKind;
const Type& fType;
diff --git a/src/sksl/ir/SkSLExpressionStatement.h b/src/sksl/ir/SkSLExpressionStatement.h
index 677c647587..088b1c9ad1 100644
--- a/src/sksl/ir/SkSLExpressionStatement.h
+++ b/src/sksl/ir/SkSLExpressionStatement.h
@@ -25,7 +25,7 @@ struct ExpressionStatement : public Statement {
return fExpression->description() + ";";
}
- const std::unique_ptr<Expression> fExpression;
+ std::unique_ptr<Expression> fExpression;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLFieldAccess.h b/src/sksl/ir/SkSLFieldAccess.h
index fb727e017e..de26a3f626 100644
--- a/src/sksl/ir/SkSLFieldAccess.h
+++ b/src/sksl/ir/SkSLFieldAccess.h
@@ -35,7 +35,7 @@ struct FieldAccess : public Expression {
return fBase->description() + "." + fBase->fType.fields()[fFieldIndex].fName;
}
- const std::unique_ptr<Expression> fBase;
+ std::unique_ptr<Expression> fBase;
const int fFieldIndex;
const OwnerKind fOwnerKind;
diff --git a/src/sksl/ir/SkSLForStatement.h b/src/sksl/ir/SkSLForStatement.h
index ff03d0d7f9..6f03e2bb36 100644
--- a/src/sksl/ir/SkSLForStatement.h
+++ b/src/sksl/ir/SkSLForStatement.h
@@ -46,8 +46,8 @@ struct ForStatement : public Statement {
}
const std::unique_ptr<Statement> fInitializer;
- const std::unique_ptr<Expression> fTest;
- const std::unique_ptr<Expression> fNext;
+ std::unique_ptr<Expression> fTest;
+ std::unique_ptr<Expression> fNext;
const std::unique_ptr<Statement> fStatement;
const std::shared_ptr<SymbolTable> fSymbols;
diff --git a/src/sksl/ir/SkSLFunctionCall.h b/src/sksl/ir/SkSLFunctionCall.h
index 971af366b9..1838076796 100644
--- a/src/sksl/ir/SkSLFunctionCall.h
+++ b/src/sksl/ir/SkSLFunctionCall.h
@@ -36,7 +36,7 @@ struct FunctionCall : public Expression {
}
const FunctionDeclaration& fFunction;
- const std::vector<std::unique_ptr<Expression>> fArguments;
+ std::vector<std::unique_ptr<Expression>> fArguments;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLIfStatement.h b/src/sksl/ir/SkSLIfStatement.h
index f8beded9e8..8667e932ec 100644
--- a/src/sksl/ir/SkSLIfStatement.h
+++ b/src/sksl/ir/SkSLIfStatement.h
@@ -32,7 +32,7 @@ struct IfStatement : public Statement {
return result;
}
- const std::unique_ptr<Expression> fTest;
+ std::unique_ptr<Expression> fTest;
const std::unique_ptr<Statement> fIfTrue;
const std::unique_ptr<Statement> fIfFalse;
diff --git a/src/sksl/ir/SkSLIndexExpression.h b/src/sksl/ir/SkSLIndexExpression.h
index 079dde5e53..d255c7daf6 100644
--- a/src/sksl/ir/SkSLIndexExpression.h
+++ b/src/sksl/ir/SkSLIndexExpression.h
@@ -55,8 +55,8 @@ struct IndexExpression : public Expression {
return fBase->description() + "[" + fIndex->description() + "]";
}
- const std::unique_ptr<Expression> fBase;
- const std::unique_ptr<Expression> fIndex;
+ std::unique_ptr<Expression> fBase;
+ std::unique_ptr<Expression> fIndex;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLPostfixExpression.h b/src/sksl/ir/SkSLPostfixExpression.h
index 01671b5b88..6c9fafe5a0 100644
--- a/src/sksl/ir/SkSLPostfixExpression.h
+++ b/src/sksl/ir/SkSLPostfixExpression.h
@@ -25,7 +25,7 @@ struct PostfixExpression : public Expression {
return fOperand->description() + Token::OperatorName(fOperator);
}
- const std::unique_ptr<Expression> fOperand;
+ std::unique_ptr<Expression> fOperand;
const Token::Kind fOperator;
typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLPrefixExpression.h b/src/sksl/ir/SkSLPrefixExpression.h
index 790c5ab47a..b7db99a0a4 100644
--- a/src/sksl/ir/SkSLPrefixExpression.h
+++ b/src/sksl/ir/SkSLPrefixExpression.h
@@ -25,7 +25,7 @@ struct PrefixExpression : public Expression {
return Token::OperatorName(fOperator) + fOperand->description();
}
- const std::unique_ptr<Expression> fOperand;
+ std::unique_ptr<Expression> fOperand;
const Token::Kind fOperator;
typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index ac49d6dcc7..6a73be6983 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -59,8 +59,8 @@ struct Program {
, fSettings(settings)
, fDefaultPrecision(defaultPrecision)
, fContext(context)
- , fElements(std::move(elements))
, fSymbols(symbols)
+ , fElements(std::move(elements))
, fInputs(inputs) {}
Kind fKind;
@@ -68,8 +68,10 @@ struct Program {
// FIXME handle different types; currently it assumes this is for floats
Modifiers::Flag fDefaultPrecision;
Context* fContext;
- std::vector<std::unique_ptr<ProgramElement>> fElements;
+ // it's important to keep fElements defined after (and thus destroyed before) fSymbols,
+ // because destroying elements can modify reference counts in symbols
std::shared_ptr<SymbolTable> fSymbols;
+ std::vector<std::unique_ptr<ProgramElement>> fElements;
Inputs fInputs;
};
diff --git a/src/sksl/ir/SkSLReturnStatement.h b/src/sksl/ir/SkSLReturnStatement.h
index c83b45066e..dc5ec9aa9c 100644
--- a/src/sksl/ir/SkSLReturnStatement.h
+++ b/src/sksl/ir/SkSLReturnStatement.h
@@ -32,7 +32,7 @@ struct ReturnStatement : public Statement {
}
}
- const std::unique_ptr<Expression> fExpression;
+ std::unique_ptr<Expression> fExpression;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLSwizzle.h b/src/sksl/ir/SkSLSwizzle.h
index c9397aec7f..8ad9001ada 100644
--- a/src/sksl/ir/SkSLSwizzle.h
+++ b/src/sksl/ir/SkSLSwizzle.h
@@ -76,7 +76,7 @@ struct Swizzle : public Expression {
return result;
}
- const std::unique_ptr<Expression> fBase;
+ std::unique_ptr<Expression> fBase;
const std::vector<int> fComponents;
typedef Expression INHERITED;
diff --git a/src/sksl/ir/SkSLTernaryExpression.h b/src/sksl/ir/SkSLTernaryExpression.h
index 4a352536e3..02750049d4 100644
--- a/src/sksl/ir/SkSLTernaryExpression.h
+++ b/src/sksl/ir/SkSLTernaryExpression.h
@@ -31,9 +31,9 @@ struct TernaryExpression : public Expression {
fIfFalse->description() + ")";
}
- const std::unique_ptr<Expression> fTest;
- const std::unique_ptr<Expression> fIfTrue;
- const std::unique_ptr<Expression> fIfFalse;
+ std::unique_ptr<Expression> fTest;
+ std::unique_ptr<Expression> fIfTrue;
+ std::unique_ptr<Expression> fIfFalse;
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLVarDeclarations.h b/src/sksl/ir/SkSLVarDeclarations.h
index 295c0b6997..490259a081 100644
--- a/src/sksl/ir/SkSLVarDeclarations.h
+++ b/src/sksl/ir/SkSLVarDeclarations.h
@@ -72,7 +72,7 @@ struct VarDeclarations : public ProgramElement {
}
const Type& fBaseType;
- const std::vector<VarDeclaration> fVars;
+ std::vector<VarDeclaration> fVars;
typedef ProgramElement INHERITED;
};
diff --git a/src/sksl/ir/SkSLVarDeclarationsStatement.h b/src/sksl/ir/SkSLVarDeclarationsStatement.h
index 7a29656593..66b570f853 100644
--- a/src/sksl/ir/SkSLVarDeclarationsStatement.h
+++ b/src/sksl/ir/SkSLVarDeclarationsStatement.h
@@ -18,14 +18,14 @@ namespace SkSL {
*/
struct VarDeclarationsStatement : public Statement {
VarDeclarationsStatement(std::unique_ptr<VarDeclarations> decl)
- : INHERITED(decl->fPosition, kVarDeclarations_Kind)
+ : INHERITED(decl->fPosition, kVarDeclarations_Kind)
, fDeclaration(std::move(decl)) {}
SkString description() const override {
return fDeclaration->description();
}
- const std::shared_ptr<VarDeclarations> fDeclaration;
+ std::shared_ptr<VarDeclarations> fDeclaration;
typedef Statement INHERITED;
};
diff --git a/src/sksl/ir/SkSLVariable.h b/src/sksl/ir/SkSLVariable.h
index 39b8482a7b..2c3391dfa2 100644
--- a/src/sksl/ir/SkSLVariable.h
+++ b/src/sksl/ir/SkSLVariable.h
@@ -33,8 +33,8 @@ struct Variable : public Symbol {
, fModifiers(modifiers)
, fType(type)
, fStorage(storage)
- , fIsReadFrom(false)
- , fIsWrittenTo(false) {}
+ , fReadCount(0)
+ , fWriteCount(0) {}
virtual SkString description() const override {
return fModifiers.description() + fType.fName + " " + fName;
@@ -44,8 +44,12 @@ struct Variable : public Symbol {
const Type& fType;
const Storage fStorage;
- mutable bool fIsReadFrom;
- mutable bool fIsWrittenTo;
+ // Tracks how many sites read from the variable. If this is zero for a non-out variable (or
+ // becomes zero during optimization), the variable is dead and may be eliminated.
+ mutable int fReadCount;
+ // Tracks how many sites write to the variable. If this is zero, the variable is dead and may be
+ // eliminated.
+ mutable int fWriteCount;
typedef Symbol INHERITED;
};
diff --git a/src/sksl/ir/SkSLVariableReference.h b/src/sksl/ir/SkSLVariableReference.h
index c6a2ea0511..fecb04e2e5 100644
--- a/src/sksl/ir/SkSLVariableReference.h
+++ b/src/sksl/ir/SkSLVariableReference.h
@@ -20,16 +20,83 @@ namespace SkSL {
* there is only one Variable 'x', but two VariableReferences to it.
*/
struct VariableReference : public Expression {
- VariableReference(Position position, const Variable& variable)
+ enum RefKind {
+ kRead_RefKind,
+ kWrite_RefKind,
+ kReadWrite_RefKind
+ };
+
+ VariableReference(Position position, const Variable& variable, RefKind refKind = kRead_RefKind)
: INHERITED(position, kVariableReference_Kind, variable.fType)
- , fVariable(variable) {}
+ , fVariable(variable)
+ , fRefKind(refKind) {
+ if (refKind != kRead_RefKind) {
+ fVariable.fWriteCount++;
+ }
+ if (refKind != kWrite_RefKind) {
+ fVariable.fReadCount++;
+ }
+ }
+
+ virtual ~VariableReference() override {
+ if (fRefKind != kWrite_RefKind) {
+ fVariable.fReadCount--;
+ }
+ }
+
+ RefKind refKind() {
+ return fRefKind;
+ }
+
+ void setRefKind(RefKind refKind) {
+ if (fRefKind != kRead_RefKind) {
+ fVariable.fWriteCount--;
+ }
+ if (fRefKind != kWrite_RefKind) {
+ fVariable.fReadCount--;
+ }
+ if (refKind != kRead_RefKind) {
+ fVariable.fWriteCount++;
+ }
+ if (refKind != kWrite_RefKind) {
+ fVariable.fReadCount++;
+ }
+ fRefKind = refKind;
+ }
SkString description() const override {
return fVariable.fName;
}
+ virtual std::unique_ptr<Expression> constantPropagate(
+ const IRGenerator& irGenerator,
+ const DefinitionMap& definitions) override {
+ auto exprIter = definitions.find(&fVariable);
+ if (exprIter != definitions.end() && exprIter->second) {
+ const Expression* expr = exprIter->second->get();
+ switch (expr->fKind) {
+ case Expression::kIntLiteral_Kind:
+ return std::unique_ptr<Expression>(new IntLiteral(
+ irGenerator.fContext,
+ Position(),
+ ((IntLiteral*) expr)->fValue));
+ case Expression::kFloatLiteral_Kind:
+ return std::unique_ptr<Expression>(new FloatLiteral(
+ irGenerator.fContext,
+ Position(),
+ ((FloatLiteral*) expr)->fValue));
+ default:
+ break;
+ }
+ }
+ return nullptr;
+ }
+
const Variable& fVariable;
+private:
+ RefKind fRefKind;
+
typedef Expression INHERITED;
};
diff --git a/src/sksl/ir/SkSLWhileStatement.h b/src/sksl/ir/SkSLWhileStatement.h
index 7c6a2907c4..a741a0441d 100644
--- a/src/sksl/ir/SkSLWhileStatement.h
+++ b/src/sksl/ir/SkSLWhileStatement.h
@@ -27,7 +27,7 @@ struct WhileStatement : public Statement {
return "while (" + fTest->description() + ") " + fStatement->description();
}
- const std::unique_ptr<Expression> fTest;
+ std::unique_ptr<Expression> fTest;
const std::unique_ptr<Statement> fStatement;
typedef Statement INHERITED;