| //===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Defines the Cooperative Matrix operations in the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "SPIRVParsingUtils.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "llvm/ADT/STLExtras.h" |
| |
| using namespace mlir::spirv::AttrNames; |
| |
| namespace mlir::spirv { |
| |
| static LogicalResult |
| verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, |
| spirv::MemoryAccessAttr memoryOperand, |
| IntegerAttr alignment) { |
| auto pointerType = cast<PointerType>(pointer); |
| Type pointeeType = pointerType.getPointeeType(); |
| if (!isa<ScalarType, VectorType>(pointeeType)) { |
| return op->emitOpError( |
| "Pointer must point to a scalar or vector type but provided ") |
| << pointeeType; |
| } |
| |
| if (memoryOperand) { |
| spirv::MemoryAccess operandSet = memoryOperand.getValue(); |
| |
| if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) && |
| spirv::bitEnumContainsAll(operandSet, |
| spirv::MemoryAccess::MakePointerAvailable)) { |
| return op->emitOpError( |
| "not compatible with memory operand 'MakePointerAvailable'"); |
| } |
| |
| if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) && |
| spirv::bitEnumContainsAll(operandSet, |
| spirv::MemoryAccess::MakePointerVisible)) { |
| return op->emitOpError( |
| "not compatible with memory operand 'MakePointerVisible'"); |
| } |
| |
| // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See |
| // #145485. |
| |
| if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) && |
| !alignment) { |
| return op->emitOpError("missing value for the 'Aligned' memory operand"); |
| } |
| |
| if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) && |
| alignment) { |
| return op->emitOpError( |
| "found alignment attribute for non-'Aligned' memory operand"); |
| } |
| } |
| |
| // TODO: Verify the memory object behind the pointer: |
| // > If the Shader capability was declared, Pointer must point into an array |
| // > and any ArrayStride decoration on Pointer is ignored. |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.KHR.CooperativeMatrixLoad |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult KHRCooperativeMatrixLoadOp::verify() { |
| return verifyCoopMatrixAccess(*this, getPointer().getType(), |
| getResult().getType(), getMemoryOperandAttr(), |
| getAlignmentAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.KHR.CooperativeMatrixStore |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult KHRCooperativeMatrixStoreOp::verify() { |
| return verifyCoopMatrixAccess(*this, getPointer().getType(), |
| getObject().getType(), getMemoryOperandAttr(), |
| getAlignmentAttr()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // spirv.KHR.CooperativeMatrixMulAdd |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult KHRCooperativeMatrixMulAddOp::verify() { |
| auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType()); |
| auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType()); |
| auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType()); |
| |
| // Check element types. ODS enforces that `type(c) == type(result)`, so no |
| // need to check it here. |
| |
| // Check the 'use' part of the type against the operands and the result. |
| if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA) |
| return emitOpError("operand #0 must be of use 'MatrixA'"); |
| if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB) |
| return emitOpError("operand #1 must be of use 'MatrixB'"); |
| if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc) |
| return emitOpError("operand #2 must be of use 'MatrixAcc'"); |
| |
| // Check the 'scope' part of the type. |
| if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()})) |
| return emitOpError("matrix scope mismatch"); |
| |
| // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'. |
| if (typeA.getRows() != typeC.getRows()) |
| return emitOpError("matrix size mismatch on dimension 'M'"); |
| if (typeB.getColumns() != typeC.getColumns()) |
| return emitOpError("matrix size mismatch on dimension 'N'"); |
| if (typeA.getColumns() != typeB.getRows()) |
| return emitOpError("matrix size mismatch on dimension 'K'"); |
| |
| // The spec does not restrict the element types: |
| // > A, B, C, and Result Type need not necessarily have the same component |
| // > type, this is defined by the client API. |
| |
| // Check that if Cooperative Matrix Operands are provided, the element type |
| // is integer. |
| if (getMatrixOperands()) { |
| Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(), |
| typeC.getElementType()}; |
| if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) { |
| return emitOpError("Matrix Operands require all matrix element types to " |
| "be Integer Types"); |
| } |
| } |
| |
| // Any further requirements need to be checked against VCE. |
| return success(); |
| } |
| |
| } // namespace mlir::spirv |