blob: c0c4394f73d4a3de89c97381b3583c85f211c2ad [file] [log] [blame]
//===- XeGPUPropagateLayout.cpp - XeGPU Layout Propagation ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
#define DEBUG_TYPE "xegpu-propagate-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
using namespace mlir;
using namespace mlir::dataflow;
namespace {
//===----------------------------------------------------------------------===//
// Layout
//===----------------------------------------------------------------------===//
/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
struct Layout {
SmallVector<int64_t, 3> layout;
Layout() = default;
Layout(std::initializer_list<int64_t> list) : layout(list) {}
void print(llvm::raw_ostream &os) const;
size_t size() const { return layout.size(); }
};
void Layout::print(llvm::raw_ostream &os) const {
os << llvm::interleaved_array(layout);
}
/// LaneLayout represents the logical layout of lanes within a subgroup when it
/// accesses some value. LaneData represents the logical layout of data owned by
/// each work item.
using LaneLayout = Layout;
using LaneData = Layout;
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
/// Helper class for tracking the analysis state of an mlir value. For layout
/// propagation, the analysis state is simply the lane_layout and lane_data of
/// each value. Purpose of this analysis to propagate some unique layout for
/// each value in the program starting from a set of anchor operations (like
/// DPAS, StoreNd, etc.).
///
/// Given this, LayoutInfo satisifies the following properties:
/// 1) A LayoutInfo value can be in one of two states - `assigned` or `not
/// assigned`.
/// 2) Two LayoutInfo values are equal if they are both assigned or
/// both not assigned. The concrete value of assigned state does not matter.
/// 3) The meet operator works as follows:
/// - If current state is assigned, return the current state. (already
/// a unique layout is assigned. don't change it)
/// - Otherwise, return the other state.
struct LayoutInfo {
private:
LaneLayout laneLayout;
LaneData laneData;
xegpu::LayoutAttr layoutAttr;
public:
LayoutInfo() = default;
LayoutInfo(const LaneLayout &layout, const LaneData &data)
: laneLayout(layout), laneData(data) {}
// Two lattice values are equal if they have `some` layout. The actual
// content of the layout does not matter.
bool operator==(const LayoutInfo &other) const {
return this->isAssigned() == other.isAssigned();
}
static LayoutInfo meet(const LayoutInfo &lhs, const LayoutInfo &rhs);
static LayoutInfo join(const LayoutInfo &lhs, const LayoutInfo &rhs);
void print(raw_ostream &os) const;
bool isAssigned() const {
return laneLayout.size() > 0 && laneData.size() > 0;
}
LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
const LaneLayout &getLayout() const { return laneLayout; }
const LaneData &getData() const { return laneData; }
ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
};
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
os << "lane_layout: ";
laneLayout.print(os);
os << ", lane_data: ";
laneData.print(os);
} else {
os << "Not assigned.";
}
}
LayoutInfo LayoutInfo::meet(const LayoutInfo &lhs, const LayoutInfo &rhs) {
if (!lhs.isAssigned())
return rhs;
return lhs;
}
/// Since this is a backward analysis, join method is not used.
LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
llvm_unreachable("Join should not be triggered by layout propagation.");
}
/// Get the transposed layout according to the given permutation.
LayoutInfo
LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
LaneLayout newLayout;
LaneData newData;
for (int64_t idx : permutation) {
newLayout.layout.push_back(laneLayout.layout[idx]);
newData.layout.push_back(laneData.layout[idx]);
}
return LayoutInfo(newLayout, newData);
}
//===----------------------------------------------------------------------===//
// LayoutInfoLattice
//===----------------------------------------------------------------------===//
/// Lattice holding the LayoutInfo for each value.
struct LayoutInfoLattice : public Lattice<LayoutInfo> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoLattice)
using Lattice::Lattice;
};
/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).
/// Helper Function to get the default layout for uniform values like constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1)
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
LaneData({1}));
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
LaneData({1, 1}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
// Expecting int or float element type.
assert(vectorTy.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
return getDefaultSIMTLayoutInfo(1);
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (isScattered) {
packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
LaneData({1, packingFactor}));
}
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
LaneData({1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
"Expected 1D or 2D TensorDesc.");
// Expecting int or float element type.
assert(tdescTy.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
return getDefaultSIMTLayoutInfo(1);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
if (isScattered) {
int packingFactor =
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
LaneData({1, packingFactor}));
}
int packingFactor =
(bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
: 1;
return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
LaneData({1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
/// is set according to the following criteria:
/// * For A operand, the data must be packed in minimum
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
unsigned operandNum) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
// must have the VNNI format.
if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
xegpu::targetinfo::packedSizeInBitsForDpasB) {
LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
elementTy.getIntOrFloatBitWidth(),
1});
return LayoutInfo(layout, data);
}
// Otherwise, return the default layout for the vector type.
return getDefaultSIMTLayoutInfo(vectorTy);
}
//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//
/// Backward data flow analysis to propagate the lane_layout and lane_data of
/// each value in the program. Currently, the layouts for operands DPAS,
/// StoreNd, and StoreScatter are fixed (known before propagation). Purpose of
/// this analysis is to propagate those known layouts to all their producers and
/// (other) consumers.
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitStoreNdOp(xegpu::StoreNdOp store,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitLoadNdOp(xegpu::LoadNdOp load,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitLoadGatherOp(xegpu::LoadGatherOp load,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitTransposeOp(vector::TransposeOp transpose,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitVectorBitcastOp(vector::BitCastOp bitcast,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitCreateDescOp(xegpu::CreateDescOp createDesc,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
visitOperation(Operation *op, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) override;
void visitBranchOperand(OpOperand &operand) override {};
void visitCallOperand(OpOperand &operand) override {};
void visitExternalCall(CallOpInterface call,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) override {
};
void setToExitState(LayoutInfoLattice *lattice) override {
(void)lattice->meet(LayoutInfo());
}
};
} // namespace
LogicalResult LayoutInfoPropagation::visitOperation(
Operation *op, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
TypeSwitch<Operation *>(op)
.Case<xegpu::DpasOp>(
[&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); })
.Case<xegpu::StoreNdOp>(
[&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); })
.Case<xegpu::StoreScatterOp>([&](auto storeScatterOp) {
visitStoreScatterOp(storeScatterOp, operands, results);
})
.Case<xegpu::LoadNdOp>(
[&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); })
.Case<xegpu::LoadGatherOp>([&](auto loadGatherOp) {
visitLoadGatherOp(loadGatherOp, operands, results);
})
.Case<xegpu::CreateDescOp>([&](auto createDescOp) {
visitCreateDescOp(createDescOp, operands, results);
})
.Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
})
.Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
visitPrefetchNdOp(prefetchNdOp, operands, results);
})
.Case<vector::TransposeOp>([&](auto transposeOp) {
visitTransposeOp(transposeOp, operands, results);
})
.Case<vector::BitCastOp>([&](auto bitcastOp) {
visitVectorBitcastOp(bitcastOp, operands, results);
})
.Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
visitVectorMultiReductionOp(reductionOp, operands, results);
})
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
if (!resultInfo->getValue().isAssigned())
continue;
for (auto [operandInfo, operand] :
llvm::zip(operands, op->getOpOperands())) {
// If the operand type is not a vector or tensor descriptor, skip
// it.
if (!isa<xegpu::TensorDescType, VectorType>(
operand.get().getType()))
continue;
// Propagate the result layout to the operand.
meet(operandInfo, *resultInfo);
}
}
});
return success();
}
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Here we assign the default layout to the tensor descriptor operand of
// prefetch.
auto tdescTy = prefetch.getTensorDescType();
auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
void LayoutInfoPropagation::visitVectorMultiReductionOp(
vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
// We only consider 2D -> 1D reductions at this point.
VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
if (!resultTy || resultTy.getRank() != 1) {
reduction.emitWarning("Expecting output type to be 1D vector.");
return;
}
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
/// Propagate the layout of the result tensor to the source tensor descriptor in
/// UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
xegpu::UpdateNdOffsetOp updateNdOffset,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout of the result must be present.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
// Propagate the layout to the source operand.
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Set the layouts for DPAS A, B, and C operands.
void LayoutInfoPropagation::visitDpasOp(
xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
propagateIfChanged(
operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
propagateIfChanged(
operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
propagateIfChanged(
operands[2],
operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
}
}
/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
}
/// Propagate the layout of the value to the tensor descriptor operand in
/// LoadNdOp.
void LayoutInfoPropagation::visitLoadNdOp(
xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo valueLayout = results[0]->getValue();
// Need the layout of the value to propagate to the tensor descriptor.
if (!valueLayout.isAssigned())
return;
LayoutInfo tensorDescLayout = valueLayout;
// LoadNdOp has the transpose effect. However, at the stage of this analysis
// this effect is not expected and should be abstracted away. Emit a
// warning.
if (auto transpose = load.getTranspose()) {
load.emitWarning("Transpose effect is not expected for LoadNdOp at "
"LayoutInfoPropagation stage.");
tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
}
// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
}
/// For vector::TransposeOp, the layout of the result is transposed and
/// propagated to the operand.
void LayoutInfoPropagation::visitTransposeOp(
vector::TransposeOp transpose, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Need the layout of transpose result to propagate to the operands.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
LayoutInfo newLayout =
resultLayout.getTransposedLayout(transpose.getPermutation());
// Propagate the new layout to the vector operand.
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
/// For vector::BitCastOp, the lane_data of the source layout is changed based
/// on the bit width of the source and result types.
void LayoutInfoPropagation::visitVectorBitcastOp(
vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Need the layout of bitcast result to propagate to the operands.
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
int inElemTyBitWidth =
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
int outElemTyBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
// NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
// a warning and return.
if (inElemTyBitWidth != outElemTyBitWidth) {
bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
"layout propagation stage.");
return;
}
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
}
/// Propagate the layout of the result to the tensor descriptor, mask and offset
/// operands in LoadGatherOp.
void LayoutInfoPropagation::visitLoadGatherOp(
xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// The layout is strictly determined by the payload type.
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
if (!payloadTy) {
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
// Mask operand should have 1D default layout.
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
propagateIfChanged(operands[0], operands[0]->meet(layout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
if (load.getOffsets())
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
/// Propagate the layout of the descriptor to the vector offset operand in
/// CreateDescOp.
void LayoutInfoPropagation::visitCreateDescOp(
xegpu::CreateDescOp createDesc, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
LayoutInfo descLayout = results[0]->getValue();
// Need the layout of the descriptor to propagate to the operands.
if (!descLayout.isAssigned())
return;
// For offset operand propagate 1D default layout.
LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
/// Set the layout for the value, tensor descriptor, offset and mask operands in
/// the StoreScatterOp.
void LayoutInfoPropagation::visitStoreScatterOp(
xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
// Currently, for 2D StoreScatterOp we expect that the height dimension of
// the tensor descriptor is equal to the subgroup size. This is ensured by
// the op verifier.
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
if (!payloadTy) {
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
assert(
payloadShape[0] == xegpu::targetinfo::subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");
LayoutInfo payloadLayout =
getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
// Propagate the new layout to the mask and optional offset operand.
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
if (storeScatter.getOffsets())
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
//===----------------------------------------------------------------------===//
/// Driver class for running the LayoutInfoPropagation analysis.
class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
RunLayoutInfoPropagation(Operation *op) : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
solver.load<LayoutInfoPropagation>(symbolTable);
(void)solver.initializeAndRun(op);
}
LayoutInfo getLayoutInfo(Value val);
void printAnalysisResult(llvm::raw_ostream &os);
private:
DataFlowSolver solver;
const Operation *target;
};
} // namespace
LayoutInfo RunLayoutInfoPropagation::getLayoutInfo(Value val) {
auto *state = solver.lookupState<LayoutInfoLattice>(val);
if (!state)
return {};
return state->getValue();
}
// Print the analysis result for debugging purposes.
void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
auto printFunctionResult = [&](FunctionOpInterface funcOp) {
os << "function: " << funcOp.getName() << ":\n";
// Function arguments
for (BlockArgument arg : funcOp.getArguments()) {
LayoutInfo layout = getLayoutInfo(arg);
os << "argument: " << arg << "\n";
os << "layout : ";
layout.print(os);
os << "\n";
}
// Function ops
funcOp.walk([&](Operation *op) {
// Skip ops that do not have results
if (op->getResults().empty())
return;
os << "op : ";
// For control-flow ops, print the op name only.
if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
os << op->getName();
else
op->print(os);
os << "\n";
// Print the layout for each result.
for (auto [i, r] : llvm::enumerate(op->getResults())) {
LayoutInfo layout = getLayoutInfo(r);
os << "layout for result #" << i << ": ";
layout.print(os);
os << "\n";
}
});
};
SmallVector<FunctionOpInterface> funcOps;
if (auto modOp = dyn_cast<ModuleOp>(target)) {
for (auto funcOp : modOp.getOps<FunctionOpInterface>())
funcOps.push_back(funcOp);
// Collect all GpuFuncOps in the module.
for (auto gpuModOp : modOp.getOps<gpu::GPUModuleOp>()) {
for (auto gpuFuncOp : gpuModOp.getOps<FunctionOpInterface>())
funcOps.push_back(gpuFuncOp);
}
}
// Print the analysis result for each function.
for (FunctionOpInterface funcOp : funcOps)
printFunctionResult(funcOp);
}
using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
/// Update an operation with the layout of its results. If the result type is a
/// vector type, a temporary layout attribute is added to the operation. If the
/// result type is a tensor descriptor type, the type is updated with the layout
/// attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
GetLayoutFnTy getLayoutOfValue) {
// Region ops (like scf.for) are already handled by the updateControlFlowOps.
if (mlir::isa<mlir::RegionBranchOpInterface>(op))
return success();
// Iterate over all the results.
for (OpResult result : op->getResults()) {
Type resultType = result.getType();
// Layouts are needed only for vector and tensor descriptor types.
if (!isa<VectorType, xegpu::TensorDescType>(resultType))
continue;
// If the result has no layout but has users, emit a warning and continue.
xegpu::LayoutAttr layout = getLayoutOfValue(result);
if (!layout && result.getNumUses() > 0) {
op->emitWarning("op has users but no layout assigned for its result");
continue;
}
// If the result is a tensor descriptor type, update the tensor desc type
// with layout.
if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
auto typeWithLayout = xegpu::TensorDescType::get(
tensorDescTy.getContext(), tensorDescTy.getShape(),
tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
result.setType(typeWithLayout);
continue;
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
/// Region ops like scf.for need special handling because they have blocks
/// inside. If the blocks have tensor descriptor type as block arguments, thier
/// types must be updated. Also region op can have results that may not have any
/// users (e.g. A and B tiles). They are not assigned a layout by layout
/// analysis because they have no users. However inside the region op
/// corresponding block arguments for these results do have layouts. Therefore,
/// in this case we still need to update the result types with the layout
/// attribute. This function function updates the internal block arguments and
/// the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
/// ^bb0(block types):
/// ...
/// scf.yield ... : (yield types)
/// }
/// clang-format on
/// In this example, at scf.yield, control-flow can transfer to two successor
/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
/// itself (yield the results). So we update both the block arguments of the
/// successor region (i.e. block types) and the result types of the scf.for op
/// (i.e. out types). Note that yield types are updated by respective producers
/// inside bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
GetLayoutFnTy getLayoutOfValue) {
// Only process if the terminator is inside a region branch op.
if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
return success();
llvm::SmallVector<mlir::RegionSuccessor> successors;
llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
nullptr);
terminator.getSuccessorRegions(operands, successors);
for (mlir::RegionSuccessor &successor : successors) {
mlir::OperandRange successorOperands =
terminator.getSuccessorOperands(successor);
mlir::ValueRange successorInputs = successor.getSuccessorInputs();
for (auto [successorOperand, successorInput] :
llvm::zip(successorOperands, successorInputs)) {
Type inputType = successorInput.getType();
// We only need to operate on tensor descriptor or vector types.
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
continue;
xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
xegpu::LayoutAttr successorOperandLayout =
getLayoutOfValue(successorOperand);
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
LLVM_DEBUG(
DBGS()
<< "No layout assigned for forwarded operand in branch terminator: "
<< successorOperand << "\n");
return failure();
}
// We expect the layouts to match.
if (successorInputLayout &&
successorInputLayout != successorOperandLayout) {
LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
"operand forwarded as the argument: "
<< successorInputLayout << " vs "
<< successorOperandLayout << "\n");
return failure();
}
// Get tensor descriptor type with the layout.
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
auto newTdescTy = xegpu::TensorDescType::get(
tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
tdescTy.getEncoding(), successorOperandLayout);
successorInput.setType(newTdescTy);
continue;
}
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput))
xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
return success();
}
/// Update the function arguments and results with the layouts.
static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
mlir::FunctionOpInterface funcOp,
GetLayoutFnTy getLayoutOfValue) {
SmallVector<Type> newArgTypes;
// Update the function arguments.
for (BlockArgument arg : funcOp.getArguments()) {
Type argType = arg.getType();
newArgTypes.push_back(argType);
if (!isa<VectorType, xegpu::TensorDescType>(argType))
continue;
xegpu::LayoutAttr layout = getLayoutOfValue(arg);
if (!layout) {
LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
<< " but got none.\n");
return failure();
}
if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(argType)) {
auto newTdescTy = xegpu::TensorDescType::get(
tensorDescTy.getContext(), tensorDescTy.getShape(),
tensorDescTy.getElementType(), tensorDescTy.getEncoding(), layout);
arg.setType(newTdescTy);
newArgTypes.back() = newTdescTy;
}
}
// Update the function type with the new argument types.
// NOTE: We assume that function results are not expected to have layouts.
funcOp.setType(FunctionType::get(funcOp.getContext(), newArgTypes,
funcOp.getResultTypes()));
return success();
}
namespace {
struct XeGPUPropagateLayoutPass final
: public xegpu::impl::XeGPUPropagateLayoutBase<XeGPUPropagateLayoutPass> {
XeGPUPropagateLayoutPass() = default;
XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
: XeGPUPropagateLayoutBase(options) {}
void runOnOperation() override;
};
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
analysis.printAnalysisResult(os);
return;
}
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
return xegpu::LayoutAttr::get(
val.getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
};
mlir::OpBuilder builder(&getContext());
Operation *op = getOperation();
auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
LogicalResult r = success();
TypeSwitch<Operation *>(&op)
.Case<mlir::RegionBranchTerminatorOpInterface>(
[&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
r = updateControlFlowOps(builder, branchTermOp,
getXeGPULayoutForValue);
})
.Case<mlir::FunctionOpInterface>(
[&](mlir::FunctionOpInterface funcOp) {
r = updateFunctionOpInterface(builder, funcOp,
getXeGPULayoutForValue);
})
.Default([&](Operation *op) {
r = updateOp(builder, op, getXeGPULayoutForValue);
});
if (failed(r)) {
op.emitError("Failed to update operation with the layout.");
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) {
signalPassFailure();
return;
}
}