[mlir][vector] Update `CombineContractBroadcastMask` (#140050)
This patch updates `CombineContractBroadcastMask` to inherit from
`MaskableOpRewritePattern`, enabling it to handle masked
`vector.contract` operations. The pattern rewrites:
```mlir
%a = vector.broadcast %a_bc
%res vector.contract %a_bc, %b, ...
```
into:
```mlir
// Move the broadcast into vector.contract (by updating the indexing
// maps)
%res vector.contract %a, %b, ...
```
The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0 : vector<8x4xi32>,
%arg1 : vector<8x4xi32>,
%arg2 : vector<8x8xi32>,
%mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
%1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
%result = vector.mask %mask {
vector.contract {
indexing_maps = [#map0, #map1, #map2],
iterator_types = ["reduction", "parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
} : vector<1x8x8x4xi1> -> vector<8x8xi32>
return %result : vector<8x8xi32>
}
```
Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a `vector.shape_cast`:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
%arg0: vector<8x4xi32>,
%arg1: vector<8x4xi32>,
%arg2: vector<8x8xi32>,
%arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
%mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
%res = vector.mask %mask_sc {
vector.contract {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
} : vector<8x8x4xi1> -> vector<8x8xi32>
return %res : vector<8x8xi32>
}
```
While this isn't ideal - since it introduces a `vector.shape_cast` that
must be cleaned up later - it reflects the best we can do once the input
reaches `CombineContractBroadcastMask`. A more robust solution may
involve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.
LIMITATIONS
Currently, this pattern assumes:
* Only leading dimensions are dropped in the mask.
* All dropped dimensions must be unit-sized.Welcome to the LLVM project!
This repository contains the source code for LLVM, a toolkit for the construction of highly optimized compilers, optimizers, and run-time environments.
The LLVM project has multiple components. The core of the project is itself called “LLVM”. This contains all of the tools, libraries, and header files needed to process intermediate representations and convert them into object files. Tools include an assembler, disassembler, bitcode analyzer, and bitcode optimizer.
C-like languages use the Clang frontend. This component compiles C, C++, Objective-C, and Objective-C++ code into LLVM bitcode -- and from there into object files, using LLVM.
Other components include: the libc++ C++ standard library, the LLD linker, and more.
Consult the Getting Started with LLVM page for information on building and running LLVM.
For information on how to contribute to the LLVM project, please take a look at the Contributing to LLVM guide.
Join the LLVM Discourse forums, Discord chat, LLVM Office Hours or Regular sync-ups.
The LLVM project has adopted a code of conduct for participants to all modes of communication within the project.