diff options
-rw-r--r-- | src/sksl/SkSLIRGenerator.cpp | 20 | ||||
-rw-r--r-- | src/sksl/ir/SkSLBoolLiteral.h | 5 | ||||
-rw-r--r-- | src/sksl/ir/SkSLConstructor.h | 83 | ||||
-rw-r--r-- | src/sksl/ir/SkSLExpression.h | 13 | ||||
-rw-r--r-- | src/sksl/ir/SkSLFloatLiteral.h | 5 | ||||
-rw-r--r-- | src/sksl/ir/SkSLIntLiteral.h | 5 | ||||
-rw-r--r-- | tests/SkSLErrorTest.cpp | 3 | ||||
-rw-r--r-- | tests/SkSLGLSLTest.cpp | 52 |
8 files changed, 186 insertions, 0 deletions
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp index 56858a92ca..523b7a009d 100644 --- a/src/sksl/SkSLIRGenerator.cpp +++ b/src/sksl/SkSLIRGenerator.cpp @@ -1043,6 +1043,12 @@ std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left, return std::unique_ptr<Expression>(new Constructor(Position(), left.fType, \ std::move(args))); switch (op) { + case Token::EQEQ: + return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(), + left.compareConstant(fContext, right))); + case Token::NEQ: + return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(), + !left.compareConstant(fContext, right))); case Token::PLUS: RETURN_VEC_COMPONENTWISE_RESULT(+); case Token::MINUS: RETURN_VEC_COMPONENTWISE_RESULT(-); case Token::STAR: RETURN_VEC_COMPONENTWISE_RESULT(*); @@ -1050,6 +1056,20 @@ std::unique_ptr<Expression> IRGenerator::constantFold(const Expression& left, default: return nullptr; } } + if (left.fType.kind() == Type::kMatrix_Kind && + right.fType.kind() == Type::kMatrix_Kind && + left.fKind == right.fKind) { + switch (op) { + case Token::EQEQ: + return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(), + left.compareConstant(fContext, right))); + case Token::NEQ: + return std::unique_ptr<Expression>(new BoolLiteral(fContext, Position(), + !left.compareConstant(fContext, right))); + default: + return nullptr; + } + } #undef RESULT return nullptr; } diff --git a/src/sksl/ir/SkSLBoolLiteral.h b/src/sksl/ir/SkSLBoolLiteral.h index 13203a4e55..a4151b8b35 100644 --- a/src/sksl/ir/SkSLBoolLiteral.h +++ b/src/sksl/ir/SkSLBoolLiteral.h @@ -33,6 +33,11 @@ struct BoolLiteral : public Expression { return true; } + bool compareConstant(const Context& context, const Expression& other) const override { + BoolLiteral& b = (BoolLiteral&) other; + return fValue == b.fValue; + } + const bool fValue; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h index 208031abba..05f409649a 100644 --- a/src/sksl/ir/SkSLConstructor.h +++ b/src/sksl/ir/SkSLConstructor.h @@ -81,6 +81,44 @@ struct Constructor : public Expression { return true; } + bool compareConstant(const Context& context, const Expression& other) const override { + ASSERT(other.fKind == Expression::kConstructor_Kind && other.fType == fType); + Constructor& c = (Constructor&) other; + if (c.fType.kind() == Type::kVector_Kind) { + for (int i = 0; i < fType.columns(); i++) { + if (!this->getVecComponent(i).compareConstant(context, c.getVecComponent(i))) { + return false; + } + } + return true; + } + // shouldn't be possible to have a constant constructor that isn't a vector or matrix; + // a constant scalar constructor should have been collapsed down to the appropriate + // literal + ASSERT(fType.kind() == Type::kMatrix_Kind); + const FloatLiteral fzero(context, Position(), 0); + const IntLiteral izero(context, Position(), 0); + const Expression* zero; + if (fType.componentType() == *context.fFloat_Type) { + zero = &fzero; + } else { + ASSERT(fType.componentType() == *context.fInt_Type); + zero = &izero; + } + for (int col = 0; col < fType.columns(); col++) { + for (int row = 0; row < fType.rows(); row++) { + const Expression* component1 = getMatComponent(col, row); + const Expression* component2 = c.getMatComponent(col, row); + if (!(component1 ? component1 : zero)->compareConstant( + context, + component2 ? *component2 : *zero)) { + return false; + } + } + } + return true; + } + const Expression& getVecComponent(int index) const { ASSERT(fType.kind() == Type::kVector_Kind); if (fArguments.size() == 1 && fArguments[0]->fType.kind() == Type::kScalar_Kind) { @@ -118,6 +156,51 @@ struct Constructor : public Expression { return ((IntLiteral&) c).fValue; } + // null return should be interpreted as zero + const Expression* getMatComponent(int col, int row) const { + ASSERT(this->isConstant()); + ASSERT(fType.kind() == Type::kMatrix_Kind); + ASSERT(col < fType.columns() && row < fType.rows()); + if (fArguments.size() == 1) { + if (fArguments[0]->fType.kind() == Type::kScalar_Kind) { + // single scalar argument, so matrix is of the form: + // x 0 0 + // 0 x 0 + // 0 0 x + // return x if col == row + return col == row ? fArguments[0].get() : nullptr; + } + if (fArguments[0]->fType.kind() == Type::kMatrix_Kind) { + ASSERT(fArguments[0]->fKind == Expression::kConstructor_Kind); + // single matrix argument. make sure we're within the argument's bounds. + const Type& argType = ((Constructor&) *fArguments[0]).fType; + if (col < argType.columns() && row < argType.rows()) { + // within bounds, defer to argument + return ((Constructor&) *fArguments[0]).getMatComponent(col, row); + } + // out of bounds, return 0 + return nullptr; + } + } + int currentIndex = 0; + int targetIndex = col * fType.rows() + row; + for (const auto& arg : fArguments) { + ASSERT(targetIndex >= currentIndex); + ASSERT(arg->fType.rows() == 1); + if (currentIndex + arg->fType.columns() > targetIndex) { + if (arg->fType.columns() == 1) { + return arg.get(); + } else { + ASSERT(arg->fType.kind() == Type::kVector_Kind); + ASSERT(arg->fKind == Expression::kConstructor_Kind); + return &((Constructor&) *arg).getVecComponent(targetIndex - currentIndex); + } + } + currentIndex += arg->fType.columns(); + } + ABORT("can't happen, matrix component out of bounds"); + } + std::vector<std::unique_ptr<Expression>> fArguments; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h index 5db9ddf96f..07dad1d7df 100644 --- a/src/sksl/ir/SkSLExpression.h +++ b/src/sksl/ir/SkSLExpression.h @@ -48,11 +48,24 @@ struct Expression : public IRNode { , fKind(kind) , fType(std::move(type)) {} + /** + * Returns true if this expression is constant. compareConstant must be implemented for all + * constants! + */ virtual bool isConstant() const { return false; } /** + * Compares this constant expression against another constant expression of the same type. It is + * an error to call this on non-constant expressions, or if the types of the expressions do not + * match. + */ + virtual bool compareConstant(const Context& context, const Expression& other) const { + ABORT("cannot call compareConstant on this type"); + } + + /** * Returns true if evaluating the expression potentially has side effects. Expressions may never * return false if they actually have side effects, but it is legal (though suboptimal) to * return true if there are not actually any side effects. diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h index 8f83e2866c..21a485fb0a 100644 --- a/src/sksl/ir/SkSLFloatLiteral.h +++ b/src/sksl/ir/SkSLFloatLiteral.h @@ -34,6 +34,11 @@ struct FloatLiteral : public Expression { return true; } + bool compareConstant(const Context& context, const Expression& other) const override { + FloatLiteral& f = (FloatLiteral&) other; + return fValue == f.fValue; + } + const double fValue; typedef Expression INHERITED; diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h index 3a95ed65ba..d8eba5573a 100644 --- a/src/sksl/ir/SkSLIntLiteral.h +++ b/src/sksl/ir/SkSLIntLiteral.h @@ -35,6 +35,11 @@ struct IntLiteral : public Expression { return true; } + bool compareConstant(const Context& context, const Expression& other) const override { + IntLiteral& i = (IntLiteral&) other; + return fValue == i.fValue; + } + const int64_t fValue; typedef Expression INHERITED; diff --git a/tests/SkSLErrorTest.cpp b/tests/SkSLErrorTest.cpp index bd0c64a93b..47b9af8ee3 100644 --- a/tests/SkSLErrorTest.cpp +++ b/tests/SkSLErrorTest.cpp @@ -125,6 +125,9 @@ DEF_TEST(SkSLConstructorTypeMismatch, r) { test_failure(r, "struct foo { int x; } foo; void main() { vec2 x = vec2(foo); }", "error: 1: 'foo' is not a valid parameter to 'vec2' constructor\n1 error\n"); + test_failure(r, + "void main() { mat2 x = mat2(true); }", + "error: 1: expected 'float', but found 'bool'\n1 error\n"); } DEF_TEST(SkSLConstructorArgumentCount, r) { diff --git a/tests/SkSLGLSLTest.cpp b/tests/SkSLGLSLTest.cpp index 97d7acbce9..1dc522b07d 100644 --- a/tests/SkSLGLSLTest.cpp +++ b/tests/SkSLGLSLTest.cpp @@ -573,6 +573,34 @@ DEF_TEST(SkSLConstantFolding, r) { "sk_FragColor = vec4(2) * vec4(1, 2, 3, 4);" "sk_FragColor = vec4(12) / vec4(1, 2, 3, 4);" "sk_FragColor.r = (vec4(12) / vec4(1, 2, 3, 4)).y;" + "sk_FragColor.x = vec4(1) == vec4(1) ? 1.0 : 0.0;" + "sk_FragColor.x = vec4(1) == vec4(2) ? 1.0 : 0.0;" + "sk_FragColor.x = vec2(1) == vec2(1, 1) ? 1.0 : 0.0;" + "sk_FragColor.x = vec2(1, 1) == vec2(1, 1) ? 1.0 : 0.0;" + "sk_FragColor.x = vec2(1) == vec2(1, 0) ? 1.0 : 0.0;" + "sk_FragColor.x = vec4(1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;" + "sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;" + "sk_FragColor.x = vec4(vec3(1), 1) == vec4(vec2(1), 1, 0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) == " + "mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(vec2(1.0, 0.0), vec2(1.0, 1.0)) == " + "mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1) == mat2(1) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1) == mat2(0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(2) == mat2(vec2(1.0, 0.0), vec2(0.0, 1.0)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat3x2(2) == mat3x2(vec2(2.0, 0.0), vec2(0.0, 2.0), vec2(0.0)) ? " + "1.0 : 0.0;" + "sk_FragColor.x = vec2(1) != vec2(1, 0) ? 1.0 : 0.0;" + "sk_FragColor.x = vec4(1) != vec4(vec2(1), vec2(1)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1) != mat2(1) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1) != mat2(0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat3(vec3(1.0, 0.0, 0.0), vec3(0.0, 1.0, 0.0), vec3(0.0, 0.0, 0.0)) == " + "mat3(mat2(1.0)) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(mat3(1.0)) == mat2(1.0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(vec4(1.0, 0.0, 0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(1.0, 0.0, vec2(0.0, 1.0)) == mat2(1.0) ? 1.0 : 0.0;" + "sk_FragColor.x = mat2(vec2(1.0, 0.0), 0.0, 1.0) == mat2(1.0) ? 1.0 : 0.0;" "}", *SkSL::ShaderCapsFactory::Default(), "#version 400\n" @@ -617,6 +645,30 @@ DEF_TEST(SkSLConstantFolding, r) { " sk_FragColor = vec4(2.0, 4.0, 6.0, 8.0);\n" " sk_FragColor = vec4(12.0, 6.0, 4.0, 3.0);\n" " sk_FragColor.x = 6.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 0.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" + " sk_FragColor.x = 1.0;\n" "}\n"); } |