aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl
diff options
context:
space:
mode:
authorGravatar Ethan Nicholas <ethannicholas@google.com>2018-07-10 09:37:51 -0400
committerGravatar Skia Commit-Bot <skia-commit-bot@chromium.org>2018-07-13 17:27:44 +0000
commit0df21136e3dc5434149695f09b79aff271430365 (patch)
treef55f26ffa74e3fddb057488ba5d162603c56d069 /src/sksl
parented1205ae20b213c07305d604b2b515ab27cba085 (diff)
fixed SPIR-V matrix operations
Bug: skia: Change-Id: I23be824cdd7d00ffd0c54516a168c07e77bb4f49 Reviewed-on: https://skia-review.googlesource.com/140182 Reviewed-by: Greg Daniel <egdaniel@google.com> Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Diffstat (limited to 'src/sksl')
-rw-r--r--src/sksl/SkSLSPIRVCodeGenerator.cpp97
-rw-r--r--src/sksl/SkSLSPIRVCodeGenerator.h7
2 files changed, 84 insertions, 20 deletions
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index c8e1255352..9cc933a921 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -143,7 +143,7 @@ void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
}
static bool is_float(const Context& context, const Type& type) {
- if (type.kind() == Type::kVector_Kind) {
+ if (type.columns() > 1) {
return is_float(context, type.componentType());
}
return type == *context.fFloat_Type || type == *context.fHalf_Type ||
@@ -1822,38 +1822,67 @@ SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op
SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
SpvOp_ floatOperator, SpvOp_ intOperator,
+ SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
OutputStream& out) {
SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
SkASSERT(operandType.kind() == Type::kMatrix_Kind);
- SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
- operandType.columns(),
- 1));
+ SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
+ operandType.rows(),
+ 1));
SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
- operandType.columns(),
+ operandType.rows(),
1));
SpvId boolType = this->getType(*fContext.fBool_Type);
SpvId result = 0;
- for (int i = 0; i < operandType.rows(); i++) {
- SpvId rowL = this->nextId();
- this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
- SpvId rowR = this->nextId();
- this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
+ for (int i = 0; i < operandType.columns(); i++) {
+ SpvId columnL = this->nextId();
+ this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
+ SpvId columnR = this->nextId();
+ this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
SpvId compare = this->nextId();
- this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
- SpvId all = this->nextId();
- this->writeInstruction(SpvOpAll, boolType, all, compare, out);
+ this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
+ SpvId merge = this->nextId();
+ this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
if (result != 0) {
SpvId next = this->nextId();
- this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
+ this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
result = next;
}
else {
- result = all;
+ result = merge;
}
}
return result;
}
+SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
+ SpvId rhs, SpvOp_ floatOperator,
+ SpvOp_ intOperator,
+ OutputStream& out) {
+ SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
+ SkASSERT(operandType.kind() == Type::kMatrix_Kind);
+ SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
+ operandType.rows(),
+ 1));
+ SpvId columns[4];
+ for (int i = 0; i < operandType.columns(); i++) {
+ SpvId columnL = this->nextId();
+ this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
+ SpvId columnR = this->nextId();
+ this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
+ columns[i] = this->nextId();
+ this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
+ }
+ SpvId result = this->nextId();
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
+ this->writeWord(this->getType(operandType), out);
+ this->writeWord(result, out);
+ for (int i = 0; i < operandType.columns(); i++) {
+ this->writeWord(columns[i], out);
+ }
+ return result;
+}
+
SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
// handle cases where we don't necessarily evaluate both LHS and RHS
switch (b.fOperator) {
@@ -1964,7 +1993,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
case Token::EQEQ: {
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
- SpvOpIEqual, out);
+ SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
}
SkASSERT(resultType == *fContext.fBool_Type);
const Type* tmpType;
@@ -1983,7 +2012,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
case Token::NEQ:
if (operandType->kind() == Type::kMatrix_Kind) {
return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
- SpvOpINotEqual, out);
+ SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
}
SkASSERT(resultType == *fContext.fBool_Type);
const Type* tmpType;
@@ -2019,9 +2048,21 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
SpvOpULessThanEqual, SpvOpUndef, out);
case Token::PLUS:
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(b.fLeft->fType == b.fRight->fType);
+ return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ SpvOpFAdd, SpvOpIAdd, out);
+ }
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
case Token::MINUS:
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(b.fLeft->fType == b.fRight->fType);
+ return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ SpvOpFSub, SpvOpISub, out);
+ }
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
case Token::STAR:
@@ -2059,15 +2100,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu
return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
case Token::PLUSEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
+ SpvId result;
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(b.fLeft->fType == b.fRight->fType);
+ result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ SpvOpFAdd, SpvOpIAdd, out);
+ }
+ else {
+ result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
+ }
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
}
case Token::MINUSEQ: {
- SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
+ SpvId result;
+ if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ SkASSERT(b.fLeft->fType == b.fRight->fType);
+ result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs,
+ SpvOpFSub, SpvOpISub, out);
+ }
+ else {
+ result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
SpvOpISub, SpvOpISub, SpvOpUndef, out);
+ }
SkASSERT(lvalue);
lvalue->store(result, out);
return result;
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index fee54ad65c..16f5beb68f 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -211,7 +211,12 @@ private:
SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out);
SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator,
- SpvOp_ intOperator, OutputStream& out);
+ SpvOp_ intOperator, SpvOp_ vectorMergeOperator,
+ SpvOp_ mergeOperator, OutputStream& out);
+
+ SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs,
+ SpvOp_ floatOperator, SpvOp_ intOperator,
+ OutputStream& out);
SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs,
SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt,