blob: 24e6a9c284e26243e4e1e7f499251f3b3d477819 [file] [log] [blame]
//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
using namespace mlir;
using namespace mlir::xevm;
#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
namespace {
static constexpr uint32_t subgroupSize = 16;
template <typename Op>
LogicalResult verifyMatrixInput(Op op) {
static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
BlockPrefetch2dOp>::value,
"Unexpected template parameter");
std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
if (pitch && width && *pitch < *width)
return op->emitOpError(
"4th operand (base pitch) should be >= 2nd operand (base width)");
uint32_t elemSize = op.getElemSizeInBits();
if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
uint32_t tileHeight = op.getTileHeight();
if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
uint32_t vBlocks = op.getVBlocks();
if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
return success();
}
LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
VectorType resTy = op.getRes().getType();
if (!resTy.getElementType().isIntOrFloat())
return op.emitOpError()
<< "expecting result element type to be int or float";
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
unsigned resSize = resTy.getNumElements() * resElemTySize;
unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
op.getTileWidth() * op.getVBlocks() / subgroupSize;
if (resSize != expectedSize)
return op.emitOpError() << "result size of " << resSize
<< " bits does not match the expected size of "
<< expectedSize << " bits";
if (op.getTranspose() && op.getPackRegister())
return op.emitOpError("transpose and pack_register are mutually exclusive");
if (!op.getTranspose() && !op.getPackRegister()) {
uint32_t tileHeight = op.getTileHeight();
if (tileHeight < 1 || tileHeight > 32)
return op.emitOpError("expecting tile_height to be between 1 and 32");
uint32_t tileWidth = op.getTileWidth();
uint32_t vBlocks = op.getVBlocks();
switch (op.getElemSizeInBits()) {
case 8:
if (tileWidth < 4 || tileWidth > 64)
return op.emitOpError("expecting tile_width to be between 4 and 64");
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
if (tileWidth * vBlocks > 64)
return op.emitOpError(
"tile_width * v_blocks should be less than or equal "
"to 64 for 8 bit elements");
break;
case 16:
if (tileWidth < 2 || tileWidth > 32)
return op.emitOpError("expecting tile_width to be between 2 and 32");
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
if (tileWidth * vBlocks > 32)
return op.emitOpError(
"tile_width * v_blocks should be less than or equal "
"to 32 for 16 bit elements");
break;
case 32:
if (tileWidth < 1 || tileWidth > 16)
return op.emitOpError("expecting tile_width to be between 1 and 16");
if (vBlocks != 1 && vBlocks != 2)
return op.emitOpError("expecting v_blocks to be 1 or 2");
if (tileWidth * vBlocks > 16)
return op.emitOpError(
"tile_width * v_blocks should be less than or equal "
"to 16 for 32 bit elements");
break;
case 64:
if (tileWidth < 1 || tileWidth > 8)
return op.emitOpError("expecting tile_width to be between 1 and 8");
if (vBlocks != 1)
return op.emitOpError("expecting v_blocks to be 1");
break;
default:
return op.emitOpError(
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
}
return success();
}
if (op.getTranspose()) {
assert(!op.getPackRegister() && "Expecting pack_register should be false");
uint32_t vBlocks = op.getVBlocks();
if (vBlocks != 1)
return op.emitOpError("expecting v_blocks to be 1");
uint32_t tileHeight = op.getTileHeight();
uint32_t tileWidth = op.getTileWidth();
switch (op.getElemSizeInBits()) {
case 32:
if (tileHeight < 1 || tileHeight > 32)
return op.emitOpError("expecting tile_height to be between 1 and 32");
if (tileWidth < 1 || tileWidth > 8)
return op.emitOpError("expecting tile_width to be between 1 and 8");
break;
case 64:
if (tileHeight != 8)
return op.emitOpError(
"expecting tile_height to be 8 for 64 bit elements");
if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
return op.emitOpError("expecting tile_width to be 1, 2, or 4");
break;
default:
return op.emitOpError("transpose is only supported for 32 and 64 bit "
"elements");
}
return success();
}
assert(op.getPackRegister() && !op.getTranspose() &&
"Expecting pack_register should be true and transpose should be "
"false");
uint32_t vBlocks = op.getVBlocks();
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
uint32_t tileHeight = op.getTileHeight();
uint32_t tileWidth = op.getTileWidth();
switch (op.getElemSizeInBits()) {
case 8:
if (tileHeight < 4 || tileHeight > 32)
return op.emitOpError("expecting tile_height to be between 4 and 32");
if (tileWidth < 4 || tileWidth > 16)
return op.emitOpError("expecting tile_width to be between 4 and 16");
break;
case 16:
if (tileHeight < 2 || tileHeight > 32)
return op.emitOpError("expecting tile_height to be between 2 and 32");
if (tileWidth < 2 || tileWidth > 16)
return op.emitOpError("expecting tile_width to be between 2 and 16");
if (tileWidth * vBlocks > 32)
return op.emitOpError(
"tile_width * v_blocks should be less than or equal "
"to 32 for 16 bit elements");
break;
default:
return op.emitOpError("pack_register is only supported for 8 and 16 bit "
"elements");
}
return success();
}
static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
uint32_t tileHeight = op.getTileHeight();
if (tileHeight < 1 || tileHeight > 8)
return op.emitOpError("expecting tile_height to be between 1 and 8");
uint32_t tileWidth = op.getTileWidth();
switch (op.getElemSizeInBits()) {
case 8:
if (tileWidth < 4 || tileWidth > 64)
return op.emitOpError("expecting tile_width to be between 4 and 64");
break;
case 16:
if (tileWidth < 2 || tileWidth > 32)
return op.emitOpError("expecting tile_width to be between 2 and 32");
break;
case 32:
if (tileWidth < 1 || tileWidth > 16)
return op.emitOpError("expecting tile_width to be between 1 and 16");
break;
case 64:
if (tileWidth < 1 || tileWidth > 8)
return op.emitOpError("expecting tile_width to be between 1 and 8");
break;
default:
return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
}
uint32_t vBlocks = op.getVBlocks();
if (vBlocks != 1)
return op.emitOpError("expecting v_blocks to be 1");
return success();
}
} // namespace
LogicalResult BlockLoad2dOp::verify() {
if (verify2DBlockLoadRestriction(*this).failed())
return failure();
if (verifyMatrixInput(*this).failed())
return failure();
VectorType resTy = getRes().getType();
if (!resTy.getElementType().isIntOrFloat())
return emitOpError() << "expecting result element type to be int of float";
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
if (getElemSizeInBits() == 32 || getPackRegister()) {
if (resElemTySize != 32)
return emitOpError() << "expecting result element type to be 32 bits";
}
uint32_t tileWidth = getTileWidth();
if (getPackRegister()) {
if (tileWidth != 16)
return emitOpError(
"tile_width when pack_register is true should be equal "
"to subgroup size (16 elements)");
return success();
}
return success();
}
LogicalResult BlockStore2dOp::verify() {
if (verify2DBlockStoreRestriction(*this).failed())
return failure();
if (verifyMatrixInput(*this).failed())
return failure();
uint32_t tileWidth = getTileWidth();
switch (getElemSizeInBits()) {
case 8:
if (tileWidth != 16 && tileWidth != 32)
return emitOpError("tile_width for 8 bit elements should be equal to "
"16 or 32");
break;
case 16:
if (tileWidth != 16)
return emitOpError("tile_width for 16 bit elements should be equal "
"to 16");
break;
case 32:
if (tileWidth != 16)
return emitOpError("tile_width for 32 bit elements should be equal "
"to 16");
break;
default:
llvm_unreachable("unexpected element size");
}
return success();
}
LogicalResult BlockPrefetch2dOp::verify() {
if (verifyMatrixInput(*this).failed())
return failure();
uint32_t tileWidth = getTileWidth();
switch (getElemSizeInBits()) {
case 8:
if (tileWidth != 16 && tileWidth != 32)
return emitOpError("tile_width for 8 bit elements should be equal to "
"16 or 32");
break;
case 16:
if (tileWidth != 16)
return emitOpError("tile_width for 16 bit elements should be equal "
"to 16");
break;
case 32:
if (tileWidth != 8 && tileWidth != 16)
return emitOpError(
"tile_width for 32 bit elements should be equal to 8 or 16");
break;
default:
llvm_unreachable("unexpected element size");
}
return success();
}
LogicalResult MMAOp::verify() {
if (getC()) {
if (getResult().getType() != getC().getType())
return emitOpError("type of C operand must match result type");
}
return success();
}
LogicalResult
XeVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError, int O,
StringRef triple, StringRef chip, DictionaryAttr flags,
ArrayAttr linkFiles) {
if (O < 0 || O > 3) {
return emitError()
<< "The optimization level must be a number between 0 and 3.";
}
if (triple.empty()) {
return emitError() << "The target triple cannot be empty.";
}
if (chip.empty()) {
return emitError() << "The target chip cannot be empty.";
}
if (linkFiles) {
for (Attribute fileAttr : linkFiles) {
if (auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
StringRef filePath = fileStrAttr.getValue();
if (filePath.empty()) {
return emitError() << "File paths in linkFiles cannot be empty.";
}
if (!llvm::sys::fs::exists(filePath)) {
return emitError() << "File '" << filePath << "' does not exist.";
}
}
}
}
return success();
}
void XeVMDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"
>();
declarePromisedInterface<mlir::gpu::TargetAttrInterface,
mlir::xevm::XeVMTargetAttr>();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc"