From 0df21136e3dc5434149695f09b79aff271430365 Mon Sep 17 00:00:00 2001 From: Ethan Nicholas Date: Tue, 10 Jul 2018 09:37:51 -0400 Subject: fixed SPIR-V matrix operations Bug: skia: Change-Id: I23be824cdd7d00ffd0c54516a168c07e77bb4f49 Reviewed-on: https://skia-review.googlesource.com/140182 Reviewed-by: Greg Daniel Commit-Queue: Ethan Nicholas --- src/sksl/SkSLSPIRVCodeGenerator.cpp | 97 +++++++++++++++++++++++++++++-------- src/sksl/SkSLSPIRVCodeGenerator.h | 7 ++- 2 files changed, 84 insertions(+), 20 deletions(-) (limited to 'src/sksl') diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp index c8e1255352..9cc933a921 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.cpp +++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp @@ -143,7 +143,7 @@ void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) { } static bool is_float(const Context& context, const Type& type) { - if (type.kind() == Type::kVector_Kind) { + if (type.columns() > 1) { return is_float(context, type.componentType()); } return type == *context.fFloat_Type || type == *context.fHalf_Type || @@ -1822,38 +1822,67 @@ SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator, SpvOp_ intOperator, + SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator, OutputStream& out) { SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator; SkASSERT(operandType.kind() == Type::kMatrix_Kind); - SpvId rowType = this->getType(operandType.componentType().toCompound(fContext, - operandType.columns(), - 1)); + SpvId columnType = this->getType(operandType.componentType().toCompound(fContext, + operandType.rows(), + 1)); SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext, - operandType.columns(), + operandType.rows(), 1)); SpvId boolType = this->getType(*fContext.fBool_Type); SpvId result = 0; - for (int i = 0; i < operandType.rows(); i++) { - SpvId rowL = this->nextId(); - this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out); - SpvId rowR = this->nextId(); - this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out); + for (int i = 0; i < operandType.columns(); i++) { + SpvId columnL = this->nextId(); + this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out); + SpvId columnR = this->nextId(); + this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out); SpvId compare = this->nextId(); - this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out); - SpvId all = this->nextId(); - this->writeInstruction(SpvOpAll, boolType, all, compare, out); + this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out); + SpvId merge = this->nextId(); + this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out); if (result != 0) { SpvId next = this->nextId(); - this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out); + this->writeInstruction(mergeOperator, boolType, next, result, merge, out); result = next; } else { - result = all; + result = merge; } } return result; } +SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, + SpvId rhs, SpvOp_ floatOperator, + SpvOp_ intOperator, + OutputStream& out) { + SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator; + SkASSERT(operandType.kind() == Type::kMatrix_Kind); + SpvId columnType = this->getType(operandType.componentType().toCompound(fContext, + operandType.rows(), + 1)); + SpvId columns[4]; + for (int i = 0; i < operandType.columns(); i++) { + SpvId columnL = this->nextId(); + this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out); + SpvId columnR = this->nextId(); + this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out); + columns[i] = this->nextId(); + this->writeInstruction(op, columnType, columns[i], columnL, columnR, out); + } + SpvId result = this->nextId(); + this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out); + this->writeWord(this->getType(operandType), out); + this->writeWord(result, out); + for (int i = 0; i < operandType.columns(); i++) { + this->writeWord(columns[i], out); + } + return result; +} + SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) { // handle cases where we don't necessarily evaluate both LHS and RHS switch (b.fOperator) { @@ -1964,7 +1993,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu case Token::EQEQ: { if (operandType->kind() == Type::kMatrix_Kind) { return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual, - SpvOpIEqual, out); + SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out); } SkASSERT(resultType == *fContext.fBool_Type); const Type* tmpType; @@ -1983,7 +2012,7 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu case Token::NEQ: if (operandType->kind() == Type::kMatrix_Kind) { return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual, - SpvOpINotEqual, out); + SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out); } SkASSERT(resultType == *fContext.fBool_Type); const Type* tmpType; @@ -2019,9 +2048,21 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, SpvOpULessThanEqual, SpvOpUndef, out); case Token::PLUS: + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { + SkASSERT(b.fLeft->fType == b.fRight->fType); + return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs, + SpvOpFAdd, SpvOpIAdd, out); + } return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); case Token::MINUS: + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { + SkASSERT(b.fLeft->fType == b.fRight->fType); + return this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs, + SpvOpFSub, SpvOpISub, out); + } return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); case Token::STAR: @@ -2059,15 +2100,33 @@ SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, Outpu return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out); case Token::PLUSEQ: { - SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, + SpvId result; + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { + SkASSERT(b.fLeft->fType == b.fRight->fType); + result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs, + SpvOpFAdd, SpvOpIAdd, out); + } + else { + result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); + } SkASSERT(lvalue); lvalue->store(result, out); return result; } case Token::MINUSEQ: { - SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, + SpvId result; + if (b.fLeft->fType.kind() == Type::kMatrix_Kind && + b.fRight->fType.kind() == Type::kMatrix_Kind) { + SkASSERT(b.fLeft->fType == b.fRight->fType); + result = this->writeComponentwiseMatrixBinary(b.fLeft->fType, lhs, rhs, + SpvOpFSub, SpvOpISub, out); + } + else { + result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, out); + } SkASSERT(lvalue); lvalue->store(result, out); return result; diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h index fee54ad65c..16f5beb68f 100644 --- a/src/sksl/SkSLSPIRVCodeGenerator.h +++ b/src/sksl/SkSLSPIRVCodeGenerator.h @@ -211,7 +211,12 @@ private: SpvId foldToBool(SpvId id, const Type& operandType, SpvOp op, OutputStream& out); SpvId writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ floatOperator, - SpvOp_ intOperator, OutputStream& out); + SpvOp_ intOperator, SpvOp_ vectorMergeOperator, + SpvOp_ mergeOperator, OutputStream& out); + + SpvId writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, SpvId rhs, + SpvOp_ floatOperator, SpvOp_ intOperator, + OutputStream& out); SpvId writeBinaryOperation(const Type& resultType, const Type& operandType, SpvId lhs, SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, SpvOp_ ifUInt, -- cgit v1.2.3