[MLIR][Linalg] Use Top-Down traversal to safely optimize multi-use producer fusion (#172216)
Switches the greedy rewrite traversal for the multi-use producer fusion
pattern to Top-Down (Pre-Order).
The previous Bottom-Up (Post-Order) traversal led to a critical SSA
violation when a producer (P) had multiple users (I and C) and the first
user (I) appeared before the current consumer (C) in the block.
Processing the outer consumer (C) first and attempting to fuse P into C
would create a new fused operation, F. The rewrite would attempt to
replace P's result (used by I) with the output of F. However, since I is
located before F in the block, this replacement breaks SSA dominance
rules, leading to a crash. To ensure correctness, the first use (I) must
be processed and fused before the second use (C). Using Top-Down
traversal ensures that operations are visited and rewritten in the
correct flow order.
Take a look at this example, which represents a three-operation chain
where the first operation, P (**%13:2**), has two users: an intermediate
operation I (**%15:2**) and a final consumer C (**%17:2**):
```
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
func.func @avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
%1 = llvm.mlir.constant(31 : index) : i64
%11 = tensor.empty() : tensor<1x32x32x8xf32>
%12 = tensor.empty() : tensor<1x32x32x8xindex>
%13:2 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1x32x32x8xf32>) outs(%11, %12 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) {
^bb0(%in: f32, %out: f32, %out_0: index):
%59 = linalg.index 1 : index
linalg.yield %0, %59 : f32, index
} -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>)
%14 = tensor.empty() : tensor<1x32x32x8xi64>
%15:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>) outs(%11, %14 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
^bb0(%in: f32, %in_0: index, %out: f32, %out_1: i64):
%59 = builtin.unrealized_conversion_cast %in_0 : index to i64
linalg.yield %0, %59 : f32, i64
} -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
%16 = tensor.empty() : tensor<1x32x32x8xi64>
%17:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %13#1, %15#1 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xindex>, tensor<1x32x32x8xi64>) outs(%11, %16 : tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>) {
^bb0(%in: f32, %in_0: index, %in_1: i64, %out: f32, %out_2: i64):
%59 = llvm.sub %1, %in_1 : i64
linalg.yield %0, %59 : f32, i64
} -> (tensor<1x32x32x8xf32>, tensor<1x32x32x8xi64>)
return %17 : tensor<1x32x32x8xf32>
}
}
```
If fused op is inserted at the position of **%17**, the rewrite
mechanism must update all users of P's result (**%13**). Since the
intermediate user I (**%15**) is before the final consumer C (**%17**)
in the block, renaming I's operand (which is **%13**) to the output of
the new fused operation results in a violation of SSA dominance, causing
the compiler to crash.
Issue: [#131446](https://github.com/llvm/llvm-project/issues/131446)
---------
Co-authored-by: Milos Poletanovic <mpoletanovic@syrmia.com>
Co-authored-by: Milos Poletanovic <milos.poletanovic@htecgroup.com>Welcome to the LLVM project!
This repository contains the source code for LLVM, a toolkit for the construction of highly optimized compilers, optimizers, and run-time environments.
The LLVM project has multiple components. The core of the project is itself called “LLVM”. This contains all of the tools, libraries, and header files needed to process intermediate representations and convert them into object files. Tools include an assembler, disassembler, bitcode analyzer, and bitcode optimizer.
C-like languages use the Clang frontend. This component compiles C, C++, Objective-C, and Objective-C++ code into LLVM bitcode -- and from there into object files, using LLVM.
Other components include: the libc++ C++ standard library, the LLD linker, and more.
Consult the Getting Started with LLVM page for information on building and running LLVM.
For information on how to contribute to the LLVM project, please take a look at the Contributing to LLVM guide.
Join the LLVM Discourse forums, Discord chat, LLVM Office Hours or Regular sync-ups.
The LLVM project has adopted a code of conduct for participants to all modes of communication within the project.