blob: 945716eea57543ddbebde1a8e868a2cc5c2f60fe [file] [log] [blame]
//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
/// ops to enable the ZA storage array.
struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(func::FuncOp op,
                                PatternRewriter &rewriter) const final {
    OpBuilder::InsertionGuard guard(rewriter);
    // The enable intrinsic goes at the very top of the entry block so ZA is
    // live before any op in the body executes.
    rewriter.setInsertionPointToStart(&op.front());
    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
    // The function itself is unchanged; notify the driver of an in-place
    // update so the rewrite is recorded.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
/// disable the ZA storage array.
struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(func::ReturnOp op,
                                PatternRewriter &rewriter) const final {
    OpBuilder::InsertionGuard guard(rewriter);
    // ZA must be switched off on every exit path, so the disable intrinsic is
    // inserted immediately before the return.
    rewriter.setInsertionPoint(op);
    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
    // The return op itself is untouched; report an in-place update.
    rewriter.updateRootInPlace(op, [] {});
    return success();
  }
};
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
/// ```mlir
/// %v = arm_sme.zero : vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
/// %tile_id = arm_sme.get_tile_id : i32
/// %zero_mask = arith.shli %c17_i32, %tile_id : i32
/// "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
/// %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
/// ```
///
/// The 'arm_sme.cast_tile_to_vector' (which models the return) and the
/// 'arith.shli' (which generates the mask) will be folded away after tile
/// allocation and canonicalization.
struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
  using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ZeroOp zero, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = zero.getLoc();
    unsigned elementWidth =
        zero.getVectorType().getElementType().getIntOrFloatBitWidth();

    // Reserve a tile for the result; the `zero` intrinsic targets it by mask.
    auto tileId = rewriter.create<arm_sme::GetTileID>(
        loc, rewriter.getIntegerType(elementWidth));

    // Base zeroing mask for the *first* tile of this element size. Each bit
    // of the mask names one of the eight 64-bit tiles ZA0.D-ZA7.D; wider
    // element tiles alias several of them. The masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
    int baseMaskForSize = 0;
    switch (elementWidth) {
    case 8:
      // The single 8-bit tile ZA0.B aliases all eight 64-bit tiles.
      baseMaskForSize = 0b1111'1111;
      break;
    case 16:
      // ZA0.H aliases ZA0.D, ZA2.D, ZA4.D and ZA6.D; shift once for ZA1.H.
      baseMaskForSize = 0b0101'0101;
      break;
    case 32:
      // ZA0.S aliases ZA0.D and ZA4.D; shift by 1, 2, or 3 for ZA1.S-ZA3.S.
      baseMaskForSize = 0b0001'0001;
      break;
    case 64:
      // Each 64-bit tile ZA0.D-ZA7.D is exactly one bit.
      baseMaskForSize = 0b0000'0001;
      break;
    default:
      llvm_unreachable("bad element size");
    }
    auto maskType = rewriter.getI32Type();
    auto baseMask = rewriter.create<arith::ConstantOp>(
        loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));

    // The actual mask is the base mask shifted left by the tile ID: the tile
    // ID is the tile's index within its size class and the alias patterns
    // repeat shifted, e.g. for 32-bit tiles:
    //
    //   ZA0.S (id 0) -> 00010001 << 0
    //   ZA1.S (id 1) -> 00010001 << 1
    //   ZA2.S (id 2) -> 00010001 << 2
    //   ZA3.S (id 3) -> 00010001 << 3
    //
    // The shift (and the cast modelling the result below) fold away to
    // constants after tile allocation and canonicalization.
    auto tileMask = rewriter.create<arith::ShLIOp>(
        loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);

    // The intrinsic has no result; model the zeroed tile with a cast from
    // the tile ID back to a vector.
    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
                                                           tileId);
    return success();
  }
};
/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
arm_sme::LoadTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = loadTileSliceOp.getLoc();
auto tileType = loadTileSliceOp.getVectorType();
auto tileElementType = tileType.getElementType();
// The element width (in bits) selects which 'ld1*' intrinsic variant to
// emit in the switches below.
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
// Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
// loaded to.
auto tile = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(tileElementWidth),
loadTileSliceOp.getTile());
// Compute the address of the element the slice is loaded from.
Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
adaptor.getBase(),
adaptor.getIndices(), rewriter);
auto tileSlice = loadTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask: a splat of i1 true, with one scalable
// element per element of a tile slice (tileType.getShape()[0]).
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice,
// dispatching on the slice layout and the tile element width.
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
break;
}
} else {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
break;
}
}
// The load intrinsics have no result, replace 'arm_sme.load_tile_slice'
// with 'arm_sme.cast_tile_to_vector' to preserve dataflow.
rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
tileType, tile);
return success();
}
};
/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
struct StoreTileSliceToArmSMELowering
: public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
using ConvertOpToLLVMPattern<
arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
arm_sme::StoreTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = storeTileSliceOp.getLoc();
auto tileType = storeTileSliceOp.getVectorType();
auto tileElementType = tileType.getElementType();
// The element width (in bits) selects which 'st1*' intrinsic variant to
// emit in the switches below.
unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
// Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
// being stored.
auto tile = rewriter.create<arm_sme::CastVectorToTile>(
loc, rewriter.getIntegerType(tileElementWidth),
storeTileSliceOp.getTile());
// Compute the address of the element the slice is stored to.
Value ptr = this->getStridedElementPtr(
loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
adaptor.getIndices(), rewriter);
auto tileSlice = storeTileSliceOp.getTileSliceIndex();
// Cast tile slice to i32 for intrinsic.
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), tileSlice);
// Create all active predicate mask: a splat of i1 true, with one scalable
// element per element of a tile slice (tileType.getShape()[0]).
auto one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI1Type(),
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
// Replace the op with an 'arm_sme.intr.st1*.(horiz|vert)' intrinsic to
// store the ZA tile slice, dispatching on layout and element width. The
// store has no result, so a plain replacement suffices.
if (layout == arm_sme::TileSliceLayout::Horizontal) {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
}
} else {
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
}
}
return success();
}
};
/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
/// tile slices are currently supported.
struct MoveVectorToTileSliceToArmSMELowering
    : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
  using ConvertOpToLLVMPattern<
      arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp,
                  arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = moveVectorToTileSliceOp.getLoc();
    auto tileType = moveVectorToTileSliceOp.getTileType();
    unsigned elementWidth =
        tileType.getElementType().getIntOrFloatBitWidth();

    // Resolve the destination tile: 'arm_sme.cast_vector_to_tile' yields the
    // tile ID associated with the input tile vector.
    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
        loc, rewriter.getIntegerType(elementWidth),
        moveVectorToTileSliceOp.getTile());

    // The intrinsic expects the tile slice index as an i32, not an index.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastUIOp>(
        loc, rewriter.getI32Type(),
        moveVectorToTileSliceOp.getTileSliceIndex());

    // Build an all-true predicate covering one tile slice.
    auto trueVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predicateType =
        VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                        /*scalableDims=*/{true});
    auto allActive =
        rewriter.create<vector::SplatOp>(loc, predicateType, trueVal);

    auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);

    // 'arm_sme.intr.write.horiz' writes the vector into the tile slice.
    rewriter.create<arm_sme::aarch64_sme_write_horiz>(
        loc, tileIdI32, sliceIndexI32, allActive,
        moveVectorToTileSliceOp.getVector());

    // The intrinsic has no result; preserve dataflow by replacing the op
    // with 'arm_sme.cast_tile_to_vector' on the same tile ID.
    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
        moveVectorToTileSliceOp, tileType, tileId);
    return success();
  }
};
/// Lower `vector.outerproduct` to SME MOPA intrinsics.
///
/// Example:
///
/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
/// : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
/// vector<[4]xf32>) -> ()
///
/// Currently only supports FMOPA and BFMOPA (non-widening).
struct VectorOuterProductToArmSMELowering
    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  vector::OuterProductOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Supported: 2-D all-scalable vectors of f16/bf16/f32/f64 whose dims
    // match exactly one SME tile at the minimum streaming vector length.
    auto isSupportedType = [](VectorType vectorType) {
      // TODO: the FP outer product instruction variants are predicated on
      // different features [1]:
      //
      // * FMOPA (non-widening)
      //   * half-precision   - +sme2p1,+sme-f16f16
      //   * single-precision - +sme
      //   * double-precision - +sme-f64f64
      // * BFMOPA
      //   * half-precision   - +sme2p1,+b16b16
      //
      // It should be possible to control lowering based on target features.
      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
        return false;
      auto elementType = vectorType.getElementType();
      if (!elementType.isF16() && !elementType.isBF16() &&
          !elementType.isF32() && !elementType.isF64())
        return false;
      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
                            vectorType.getElementTypeBitWidth();
      if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
        return false;
      return true;
    };

    // Report unsupported cases as match failures rather than emitting hard
    // errors from the pattern: this lets other patterns handle the op and
    // leaves diagnostics to the conversion driver.
    auto resultVectorType = outerProductOp.getResultVectorType();
    if (!isSupportedType(resultVectorType))
      return rewriter.notifyMatchFailure(outerProductOp, "unsupported type");

    vector::CombiningKind kind = outerProductOp.getKind();
    if (kind != vector::CombiningKind::ADD)
      // TODO: support subtract.
      return rewriter.notifyMatchFailure(outerProductOp, "unsupported kind");

    auto maskableOp =
        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: support masking.
      return rewriter.notifyMatchFailure(outerProductOp,
                                         "masking is currently unsupported");

    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
      // AXPY operation not suited for SME.
      return failure();

    auto loc = outerProductOp.getLoc();

    Value acc = outerProductOp.getAcc();
    if (!acc)
      // Initialize accumulator with zero.
      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);

    // The accumulator occupies a tile; recover its tile ID for the intrinsic.
    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
        loc, rewriter.getIntegerType(elementWidth), acc);

    // Create an all-active predicate mask (used for both operand vectors).
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI1Type(),
        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
    auto predTy =
        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
                        /*scalableDims=*/{true});
    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);

    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);

    // Create 'arm_sme.intr.mopa' outer product intrinsic.
    rewriter.create<arm_sme::aarch64_sme_mopa>(
        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
        outerProductOp.getRhs());

    // The intrinsic has no result; replace the op with
    // 'arm_sme.cast_tile_to_vector' on the tile ID to preserve dataflow.
    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
        outerProductOp, resultVectorType, tileId);
    return success();
  }
};
} // namespace
/// Configure the conversion target for legalizing ArmSME for export to LLVM:
/// the SME intrinsic ops (and the cast ops that model tile dataflow) are
/// legal, 'vector.outerproduct' is illegal, and func.func/func.return are
/// legal only once the ZA enable/disable intrinsics are in place (when the
/// 'arm_za' attribute requests them).
void mlir::configureArmSMELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  target.addLegalOp<
      scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
      arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
      arm_sme::aarch64_sme_za_disable>();
  target.addLegalOp<GetTileID>();
  target.addIllegalOp<vector::OuterProductOp>();
  // Mark 'func.func' ops as legal if either:
  //   1. no 'arm_za' function attribute is present.
  //   2. the 'arm_za' function attribute is present and the first op in the
  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
    if (funcOp.isDeclaration())
      return true;
    if (!funcOp->hasAttr("arm_za"))
      return true;
    // Non-declaration functions have a non-empty entry block (it carries at
    // least a terminator), so inspecting the first op is safe.
    auto firstOp = funcOp.getBody().front().begin();
    return isa<arm_sme::aarch64_sme_za_enable>(firstOp);
  });
  // Mark 'func.return' ops as legal if either:
  //   1. no 'arm_za' function attribute is present.
  //   2. the 'arm_za' function attribute is present and the function contains
  //      an 'arm_sme::aarch64_sme_za_disable' intrinsic.
  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
    // 'func.return' is always directly nested in its 'func.func'.
    Operation *funcOp = returnOp->getParentOp();
    if (!funcOp->hasAttr("arm_za"))
      return true;
    // Only the existence of a disable intrinsic matters; stop the walk at
    // the first one found instead of visiting the whole function.
    WalkResult result =
        funcOp->walk([&](arm_sme::aarch64_sme_za_disable) {
          return WalkResult::interrupt();
        });
    return result.wasInterrupted();
  });
}
/// Populate the pattern set used to legalize ArmSME for export to LLVM. The
/// ZA enable/disable patterns are plain rewrite patterns; the lowerings to
/// intrinsics are conversion patterns that need the LLVM type converter.
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<DisableZAPattern, EnableZAPattern>(ctx);
  patterns.add<LoadTileSliceToArmSMELowering,
               MoveVectorToTileSliceToArmSMELowering,
               StoreTileSliceToArmSMELowering,
               VectorOuterProductToArmSMELowering, ZeroOpConversion>(converter);
}