| //===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Utils/Utils.h" |
| #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
| #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" |
| #include "llvm/ADT/SmallVectorExtras.h" |
| |
| #include <optional> |
| |
| #include "llvm/Support/DebugLog.h" |
| #define DEBUG_TYPE "xegpu-transforms" |
| |
| using namespace mlir; |
| using namespace mlir::transform; |
| |
| /// Assuming that `ofr` is an index attr or a param of index type |
| /// or a transform dialect handle mapped to exactly one op |
| /// with one index result, get that value and cast it to int type. |
| static DiagnosedSilenceableFailure convertMixedValuesToInt( |
| transform::TransformState &state, TransformOpInterface transformOp, |
| SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) { |
| for (OpFoldResult ofr : ofrs) { |
| // Attribute case. |
| if (auto attr = dyn_cast<Attribute>(ofr)) { |
| if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { |
| result.push_back(intAttr.getInt()); |
| continue; |
| } |
| return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; |
| } |
| |
| // Transform param case. |
| Value transformValue = cast<Value>(ofr); |
| if (isa<TransformParamTypeInterface>(transformValue.getType())) { |
| ArrayRef<Attribute> params = state.getParams(transformValue); |
| if (params.size() != 1) |
| return transformOp.emitDefiniteFailure() |
| << "requires exactly one parameter associated"; |
| result.push_back( |
| cast<IntegerAttr>(params.front()).getValue().getSExtValue()); |
| continue; |
| } |
| |
| // Payload value case. |
| auto payloadOps = state.getPayloadOps(transformValue); |
| if (!llvm::hasSingleElement(payloadOps)) { |
| DiagnosedSilenceableFailure diag = |
| transformOp.emitSilenceableError() |
| << "handle must be mapped to exactly one payload op"; |
| diag.attachNote(transformValue.getLoc()) |
| << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; |
| return diag; |
| } |
| |
| Operation *op = *payloadOps.begin(); |
| if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { |
| DiagnosedSilenceableFailure diag = |
| transformOp.emitSilenceableError() |
| << "payload op must have exactly 1 index result"; |
| diag.attachNote(op->getLoc()) |
| << "has " << op->getNumResults() << " results"; |
| return diag; |
| } |
| |
| IntegerAttr intAttr; |
| if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) |
| return transformOp.emitSilenceableError() |
| << "requires param or handle to be the result of a constant like " |
| "op"; |
| |
| result.push_back(intAttr.getInt()); |
| } |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| /// Find producer operation of type T for the given value. |
| /// It's assumed that producer ops are chained through their first operand. |
| /// Producer chain is traced trough loop block arguments (init values). |
| template <typename T> |
| static std::optional<T> findProducerOfType(Value val) { |
| Value currentValue = val; |
| if (!currentValue.getDefiningOp()) { |
| // Value may be a block argument initialized outside a loop. |
| if (val.getNumUses() == 0) { |
| LDBG() << "Failed to find producer op, value has no uses."; |
| return std::nullopt; |
| } |
| auto userOp = val.getUsers().begin(); |
| auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>(); |
| if (!parentLoop) { |
| LDBG() << "Failed to find producer op, not in a loop."; |
| return std::nullopt; |
| } |
| int64_t iterArgIdx; |
| if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) { |
| auto numInductionVars = parentLoop.getLoopInductionVars()->size(); |
| iterArgIdx = iterArg.getArgNumber() - numInductionVars; |
| currentValue = parentLoop.getInits()[iterArgIdx]; |
| } else { |
| LDBG() << "Failed to find producer op, value not in init values."; |
| return std::nullopt; |
| } |
| } |
| Operation *producerOp = currentValue.getDefiningOp(); |
| |
| if (auto matchingOp = dyn_cast<T>(producerOp)) |
| return matchingOp; |
| |
| if (producerOp->getNumOperands() == 0) |
| return std::nullopt; |
| |
| return findProducerOfType<T>(producerOp->getOperand(0)); |
| } |
| |
| /// Create a layout attribute from the given parameters. |
| static xegpu::LayoutAttr createLayoutAttr( |
| MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData, |
| std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) { |
| return xegpu::LayoutAttr::get( |
| ctx, DenseI32ArrayAttr::get(ctx, sgLayout), |
| DenseI32ArrayAttr::get(ctx, sgData), |
| instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, |
| /*lane_layout=*/nullptr, |
| /*lane_data=*/nullptr, |
| /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order)); |
| } |
| |
| /// Generate `xegpu::LayoutAttr` from op mixed layout values. |
| DiagnosedSilenceableFailure |
| getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, |
| TransformOpInterface transformOp, |
| ArrayRef<::mlir::OpFoldResult> mixedSgLayout, |
| ArrayRef<::mlir::OpFoldResult> mixedSgData, |
| ArrayRef<::mlir::OpFoldResult> mixedInstData, |
| ArrayRef<int32_t> order, |
| xegpu::LayoutAttr &layoutAttr) { |
| SmallVector<int32_t> sgLayout, sgData, instData; |
| auto status = |
| convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); |
| if (!status.succeeded()) |
| return status; |
| |
| status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); |
| if (!status.succeeded()) |
| return status; |
| |
| status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); |
| if (!status.succeeded()) |
| return status; |
| auto maybeInstData = instData.empty() |
| ? std::nullopt |
| : std::optional<ArrayRef<int32_t>>(instData); |
| |
| layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| /// Replace xegpu.create_nd_desc op with a new one with the given layout. |
| static xegpu::CreateNdDescOp |
| setDescLayout(transform::TransformRewriter &rewriter, |
| xegpu::CreateNdDescOp descOp, |
| xegpu::DistributeLayoutAttr layout) { |
| assert(descOp.getMixedOffsets().size() == 0 && |
| "create desc op with offsets is not supported"); |
| auto oldTensorDesc = descOp.getType(); |
| auto descType = xegpu::TensorDescType::get( |
| oldTensorDesc.getShape(), oldTensorDesc.getElementType(), |
| /*array_length=*/oldTensorDesc.getArrayLength(), |
| /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), |
| /*memory_space=*/oldTensorDesc.getMemorySpace(), |
| /*layout=*/layout); |
| |
| rewriter.setInsertionPointAfter(descOp); |
| auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( |
| descOp, descType, descOp.getSource(), descOp.getMixedSizes(), |
| descOp.getMixedStrides()); |
| return newDescOp; |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::GetDescOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetValues = state.getPayloadValues(getTarget()); |
| if (!llvm::hasSingleElement(targetValues)) { |
| return emitDefiniteFailure() |
| << "requires exactly one target value handle (got " |
| << llvm::range_size(targetValues) << ")"; |
| } |
| |
| auto maybeDescOp = |
| findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin()); |
| if (!maybeDescOp) { |
| return emitSilenceableFailure(getLoc()) |
| << "Could not find a matching descriptor op when walking the " |
| "producer chain of the first operand."; |
| } |
| |
| results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp}); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::SetDescLayoutOp::build(OpBuilder &builder, |
| OperationState &result, Value target, |
| ArrayRef<OpFoldResult> mixedSgLayout, |
| ArrayRef<OpFoldResult> mixedSgData, |
| ArrayRef<OpFoldResult> mixedInstData, |
| ArrayRef<int32_t> order, |
| ArrayRef<int64_t> sliceDims) { |
| SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; |
| SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; |
| dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); |
| dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); |
| dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); |
| build(builder, result, target.getType(), |
| /*target=*/target, |
| /*sg_layout=*/dynamicSgLayout, |
| /*sg_data=*/dynamicSgData, |
| /*inst_data=*/dynamicInstData, |
| /*static_sg_layout=*/staticSgLayout, |
| /*static_sg_data=*/staticSgData, |
| /*static_inst_data=*/staticInstData, |
| /*order=*/order, |
| /*slice_dims=*/sliceDims); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetOps = state.getPayloadOps(getTarget()); |
| if (!llvm::hasSingleElement(targetOps)) { |
| return emitDefiniteFailure() << "requires exactly one targetOp handle (got " |
| << llvm::range_size(targetOps) << ")"; |
| } |
| Operation *target = *targetOps.begin(); |
| |
| xegpu::LayoutAttr layoutAttr = nullptr; |
| auto status = getLayoutAttrFromOperands( |
| getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(), |
| getMixedInstData(), getOrder(), layoutAttr); |
| if (!status.succeeded()) |
| return status; |
| |
| xegpu::DistributeLayoutAttr layout = layoutAttr; |
| auto sliceDims = getSliceDims(); |
| if (sliceDims.size() > 0) { |
| // Wrap layoutAttr in a slice attribute. |
| layout = xegpu::SliceAttr::get( |
| getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); |
| } |
| |
| // For now only create_nd_desc op is supported. |
| auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); |
| if (!descOp) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Expected a xegpu.create_nd_desc op, but got: " |
| << target->getName(); |
| diag.attachNote(target->getLoc()) << "target op"; |
| return diag; |
| } |
| |
| // Set layout attr in desc op's return type. Replaces old desc op. |
| auto newdescOp = setDescLayout(rewriter, descOp, layout); |
| |
| // Map result handles. |
| results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()}); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::SetDescLayoutOp::getEffects( |
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| consumesHandle(getTargetMutable(), effects); |
| onlyReadsHandle(getSgLayoutMutable(), effects); |
| onlyReadsHandle(getSgDataMutable(), effects); |
| onlyReadsHandle(getInstDataMutable(), effects); |
| producesHandle(getOperation()->getOpResults(), effects); |
| modifiesPayload(effects); |
| } |
| |
| void transform::SetOpLayoutAttrOp::build( |
| OpBuilder &builder, OperationState &ostate, Value target, int64_t index, |
| ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, |
| ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order, |
| ArrayRef<int64_t> sliceDims, bool result, bool operand) { |
| SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; |
| SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; |
| dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); |
| dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); |
| dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); |
| build(builder, ostate, target.getType(), |
| /*target=*/target, |
| /*index=*/index, |
| /*sg_layout=*/dynamicSgLayout, |
| /*sg_data=*/dynamicSgData, |
| /*inst_data=*/dynamicInstData, |
| /*static_sg_layout=*/staticSgLayout, |
| /*static_sg_data=*/staticSgData, |
| /*static_inst_data=*/staticInstData, |
| /*order=*/order, |
| /*slice_dims=*/sliceDims, |
| /*result=*/result, |
| /*operand=*/operand); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetOps = state.getPayloadOps(getTarget()); |
| if (!llvm::hasSingleElement(targetOps)) { |
| return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " |
| << llvm::range_size(targetOps) << ")"; |
| } |
| Operation *target = *targetOps.begin(); |
| |
| bool resultTarget = getResult(); |
| bool operandTarget = getOperand(); |
| |
| int64_t index = getIndex(); |
| if (resultTarget && index >= target->getNumResults()) { |
| return emitSilenceableFailure(getLoc()) |
| << "Index exceeds the number of op results"; |
| } |
| if (operandTarget && index >= target->getNumOperands()) { |
| return emitSilenceableFailure(getLoc()) |
| << "Index exceeds the number of op operands"; |
| } |
| |
| xegpu::LayoutAttr layoutAttr = nullptr; |
| auto status = getLayoutAttrFromOperands( |
| getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(), |
| getMixedInstData(), getOrder(), layoutAttr); |
| if (!status.succeeded()) |
| return status; |
| |
| xegpu::DistributeLayoutAttr layout = layoutAttr; |
| auto sliceDims = getSliceDims(); |
| if (sliceDims.size() > 0) { |
| // Wrap layoutAttr in a slice attribute. |
| layout = xegpu::SliceAttr::get( |
| getContext(), layout, DenseI64ArrayAttr::get(getContext(), sliceDims)); |
| } |
| |
| // Set layout attribute |
| if (resultTarget) { |
| // op result |
| xegpu::setDistributeLayoutAttr(target->getResult(index), layout); |
| } else if (operandTarget) { |
| // op operand |
| xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout); |
| } else if (auto dpasOp = dyn_cast<xegpu::DpasOp>(target)) { |
| // dpas op is a special case where layout needs to be set for A, B, and C |
| if (index == 0) |
| dpasOp.getProperties().layout_a = layout; |
| else if (index == 1) |
| dpasOp.getProperties().layout_b = layout; |
| else if (index == 2) |
| dpasOp.getProperties().layout_cd = layout; |
| else { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Invalid index for setting dpas op layout: " << index; |
| diag.attachNote(target->getLoc()) << "target op"; |
| return diag; |
| } |
| } else { |
| // op's anchor layout. |
| auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(target); |
| if (!anchorOp) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Cannot set anchor layout to op: " << target->getName(); |
| diag.attachNote(target->getLoc()) << "target op"; |
| return diag; |
| } |
| anchorOp.setAnchorLayout(layout); |
| } |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::SetOpLayoutAttrOp::getEffects( |
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| onlyReadsHandle(getTargetMutable(), effects); |
| onlyReadsHandle(getSgLayoutMutable(), effects); |
| onlyReadsHandle(getSgDataMutable(), effects); |
| onlyReadsHandle(getInstDataMutable(), effects); |
| modifiesPayload(effects); |
| } |
| |
| LogicalResult transform::SetOpLayoutAttrOp::verify() { |
| if (getResult() && getOperand()) { |
| return emitOpError("Cannot set both result and operand simultaneously."); |
| } |
| return success(); |
| } |
| |
| void transform::SetGPULaunchThreadsOp::build( |
| OpBuilder &builder, OperationState &ostate, Value target, |
| ArrayRef<OpFoldResult> mixedThreads) { |
| SmallVector<int64_t> staticThreads; |
| SmallVector<Value> dynamicThreads; |
| dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads); |
| build(builder, ostate, target.getType(), |
| /*target=*/target, |
| /*threads=*/dynamicThreads, |
| /*static_threads=*/staticThreads); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetOps = state.getPayloadOps(getTarget()); |
| if (!llvm::hasSingleElement(targetOps)) { |
| return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " |
| << llvm::range_size(targetOps) << ")"; |
| } |
| Operation *target = *targetOps.begin(); |
| |
| auto launchOp = dyn_cast<gpu::LaunchOp>(target); |
| if (!launchOp) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Expected a gpu.launch op, but got: " << target->getName(); |
| diag.attachNote(target->getLoc()) << "target op"; |
| return diag; |
| } |
| |
| SmallVector<int32_t> threads; |
| DiagnosedSilenceableFailure status = |
| convertMixedValuesToInt(state, (*this), threads, getMixedThreads()); |
| if (!status.succeeded()) |
| return status; |
| |
| if (threads.size() != 3) { |
| return emitSilenceableFailure(getLoc()) |
| << "Expected threads argument to consist of three values (got " |
| << threads.size() << ")"; |
| } |
| |
| rewriter.setInsertionPoint(launchOp); |
| auto createConstValue = [&](int value) { |
| return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value); |
| }; |
| |
| // Replace threads in-place. |
| launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0])); |
| launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1])); |
| launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2])); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::SetGPULaunchThreadsOp::getEffects( |
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| onlyReadsHandle(getTargetMutable(), effects); |
| onlyReadsHandle(getThreadsMutable(), effects); |
| modifiesPayload(effects); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetValues = state.getPayloadValues(getTarget()); |
| if (!llvm::hasSingleElement(targetValues)) |
| return emitDefiniteFailure() |
| << "requires exactly one target value handle (got " |
| << llvm::range_size(targetValues) << ")"; |
| auto value = *targetValues.begin(); |
| |
| int64_t nbPrefetch = getStaticNbPrefetch(); |
| if (getDynamicNbPrefetch()) { |
| // Get dynamic prefetch count from transform param or handle. |
| SmallVector<int32_t> dynamicNbPrefetch; |
| auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch, |
| {getDynamicNbPrefetch()}); |
| if (!status.succeeded()) |
| return status; |
| if (dynamicNbPrefetch.size() != 1) |
| return emitDefiniteFailure() |
| << "requires exactly one value for dynamic_nb_prefetch"; |
| nbPrefetch = dynamicNbPrefetch[0]; |
| } |
| if (nbPrefetch <= 0) |
| return emitSilenceableFailure(getLoc()) |
| << "nb_prefetch must be a positive integer."; |
| |
| // Find load operation of the operand. |
| auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value); |
| if (!maybeLoadOp) |
| return emitSilenceableFailure(getLoc()) << "Could not find load op."; |
| auto loadOp = *maybeLoadOp; |
| if (loadOp.getMixedOffsets().size() == 0) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Load op must have offsets."; |
| diag.attachNote(loadOp.getLoc()) << "load op"; |
| return diag; |
| } |
| |
| // Find the parent scf.for loop. |
| auto forOp = loadOp->getParentOfType<scf::ForOp>(); |
| if (!forOp) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "Load op is not contained in a scf.for loop."; |
| diag.attachNote(loadOp.getLoc()) << "load op"; |
| return diag; |
| } |
| |
| // Find descriptor op. |
| auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value); |
| if (!maybeDescOp) |
| return emitSilenceableFailure(getLoc()) << "Could not find descriptor op."; |
| auto descOp = *maybeDescOp; |
| if (descOp.getMixedOffsets().size() > 0) { |
| auto diag = emitSilenceableFailure(getLoc()) |
| << "desc op with offsets is not supported."; |
| diag.attachNote(descOp.getLoc()) << "desc op"; |
| } |
| |
| // Clone desc op outside the loop. |
| rewriter.setInsertionPoint(forOp); |
| auto newDescOp = |
| cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation())); |
| |
| // Clone reduction loop to emit initial prefetches. |
| // Compute upper bound of the init loop: start + nbPrefetch * step. |
| auto nbPrefetchCst = |
| arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch); |
| auto nbStep = rewriter.createOrFold<arith::MulIOp>( |
| forOp.getLoc(), nbPrefetchCst, forOp.getStep()); |
| auto initUpBound = rewriter.createOrFold<arith::AddIOp>( |
| forOp.getLoc(), forOp.getLowerBound(), nbStep); |
| auto initForOp = |
| scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), |
| initUpBound, forOp.getStep()); |
| |
| auto ctx = rewriter.getContext(); |
| auto readCacheHint = |
| xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED); |
| |
| // Modify loadOp mixedOffsets by replacing the for loop induction variable |
| // with the given value. |
| auto getPrefetchOffsets = |
| [&](Value replacementVal) -> SmallVector<OpFoldResult> { |
| IRMapping mapping; |
| mapping.map(forOp.getInductionVar(), replacementVal); |
| SmallVector<Value> dynamicOffsets = |
| llvm::map_to_vector(loadOp.getOffsets(), [&](Value v) { |
| return mapping.lookupOrDefault(v); |
| }); |
| auto constOffsets = loadOp.getConstOffsets().value(); |
| return getMixedValues(constOffsets, dynamicOffsets, ctx); |
| }; |
| |
| // Insert prefetch op in init loop. |
| // Replace induction var with the init loop induction var. |
| rewriter.setInsertionPointToStart(initForOp.getBody()); |
| xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), |
| newDescOp.getResult(), |
| getPrefetchOffsets(initForOp.getInductionVar()), |
| readCacheHint, readCacheHint, readCacheHint, |
| /*layout=*/nullptr); |
| |
| // Insert prefetch op in main loop. |
| // Calculate prefetch offset after the init prefetches have been issued. |
| rewriter.setInsertionPointToStart(forOp.getBody()); |
| auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(), |
| forOp.getInductionVar(), nbStep); |
| // Replace induction var with correct offset. |
| xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(), |
| newDescOp.getResult(), |
| getPrefetchOffsets(prefetchOffset), readCacheHint, |
| readCacheHint, readCacheHint, /*layout=*/nullptr); |
| |
| // Unroll the init loop. |
| if (failed(loopUnrollFull(initForOp))) |
| return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop"; |
| |
| results.set(llvm::cast<OpResult>(getResult()), {newDescOp}); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::InsertPrefetchOp::getEffects( |
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| onlyReadsHandle(getTargetMutable(), effects); |
| onlyReadsHandle(getDynamicNbPrefetchMutable(), effects); |
| producesHandle(getOperation()->getOpResults(), effects); |
| modifiesPayload(effects); |
| } |
| |
| void transform::ConvertLayoutOp::build( |
| OpBuilder &builder, OperationState &ostate, Value target, |
| ArrayRef<OpFoldResult> mixedInputSgLayout, |
| ArrayRef<OpFoldResult> mixedInputSgData, |
| ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder, |
| ArrayRef<OpFoldResult> mixedTargetSgLayout, |
| ArrayRef<OpFoldResult> mixedTargetSgData, |
| ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) { |
| SmallVector<int64_t> staticInputSgLayout, staticInputSgData, |
| staticInputInstData; |
| SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData, |
| dynamicInputInstData; |
| dispatchIndexOpFoldResults(mixedInputSgLayout, dynamicInputSgLayout, |
| staticInputSgLayout); |
| dispatchIndexOpFoldResults(mixedInputSgData, dynamicInputSgData, |
| staticInputSgData); |
| dispatchIndexOpFoldResults(mixedInputInstData, dynamicInputInstData, |
| staticInputInstData); |
| SmallVector<int64_t> staticTargetSgLayout, staticTargetSgData, |
| staticTargetInstData; |
| SmallVector<Value> dynamicTargetSgLayout, dynamicTargetSgData, |
| dynamicTargetInstData; |
| dispatchIndexOpFoldResults(mixedTargetSgLayout, dynamicTargetSgLayout, |
| staticTargetSgLayout); |
| dispatchIndexOpFoldResults(mixedTargetSgData, dynamicTargetSgData, |
| staticTargetSgData); |
| dispatchIndexOpFoldResults(mixedTargetInstData, dynamicTargetInstData, |
| staticTargetInstData); |
| build(builder, ostate, target.getType(), |
| /*target=*/target, |
| /*input_sg_layout=*/dynamicInputSgLayout, |
| /*input_sg_data=*/dynamicInputSgData, |
| /*input_inst_data=*/dynamicInputInstData, |
| /*target_sg_layout=*/dynamicTargetSgLayout, |
| /*target_sg_data=*/dynamicTargetSgData, |
| /*target_inst_data=*/dynamicTargetInstData, |
| /*input_order=*/inputOrder, |
| /*static_input_sg_layout=*/staticInputSgLayout, |
| /*static_input_sg_data=*/staticInputSgData, |
| /*static_input_inst_data=*/staticInputInstData, |
| /*static_target_sg_layout=*/staticTargetSgLayout, |
| /*static_target_sg_data=*/staticTargetSgData, |
| /*static_target_inst_data=*/staticTargetInstData, |
| /*target_order=*/targetOrder); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter, |
| transform::TransformResults &results, |
| transform::TransformState &state) { |
| auto targetValues = state.getPayloadValues(getTarget()); |
| if (!llvm::hasSingleElement(targetValues)) |
| return emitDefiniteFailure() |
| << "requires exactly one target value handle (got " |
| << llvm::range_size(targetValues) << ")"; |
| auto value = *targetValues.begin(); |
| |
| // Construct layout attributes. |
| xegpu::LayoutAttr inputLayoutAttr = nullptr; |
| auto status = getLayoutAttrFromOperands( |
| getContext(), state, (*this), getMixedInputSgLayout(), |
| getMixedInputSgData(), getMixedInputInstData(), getInputOrder(), |
| inputLayoutAttr); |
| if (!status.succeeded()) |
| return status; |
| |
| xegpu::LayoutAttr targetLayoutAttr = nullptr; |
| status = getLayoutAttrFromOperands( |
| getContext(), state, (*this), getMixedTargetSgLayout(), |
| getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(), |
| targetLayoutAttr); |
| if (!status.succeeded()) |
| return status; |
| |
| // Find first user op to define insertion point for layout conversion. |
| if (value.use_empty()) |
| return emitSilenceableFailure(getLoc()) |
| << "Value has no users to insert layout conversion."; |
| Operation *userOp = *value.getUsers().begin(); |
| |
| // Emit convert_layout op. |
| rewriter.setInsertionPoint(userOp); |
| auto convLayoutOp = |
| xegpu::ConvertLayoutOp::create(rewriter, value.getLoc(), value.getType(), |
| value, inputLayoutAttr, targetLayoutAttr); |
| // Replace load op result with the converted layout. |
| rewriter.replaceUsesWithIf( |
| value, convLayoutOp.getResult(), [&](OpOperand &use) { |
| return use.getOwner() != convLayoutOp.getOperation(); |
| }); |
| |
| results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp}); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void transform::ConvertLayoutOp::getEffects( |
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| onlyReadsHandle(getTargetMutable(), effects); |
| onlyReadsHandle(getInputSgLayoutMutable(), effects); |
| onlyReadsHandle(getInputSgDataMutable(), effects); |
| onlyReadsHandle(getInputInstDataMutable(), effects); |
| onlyReadsHandle(getTargetSgLayoutMutable(), effects); |
| onlyReadsHandle(getTargetSgDataMutable(), effects); |
| onlyReadsHandle(getTargetInstDataMutable(), effects); |
| producesHandle(getOperation()->getOpResults(), effects); |
| modifiesPayload(effects); |
| } |
| |
| namespace { |
| class XeGPUTransformDialectExtension |
| : public transform::TransformDialectExtension< |
| XeGPUTransformDialectExtension> { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) |
| |
| using Base::Base; |
| |
| void init(); |
| }; |
| |
| void XeGPUTransformDialectExtension::init() { |
| declareGeneratedDialect<scf::SCFDialect>(); |
| declareGeneratedDialect<arith::ArithDialect>(); |
| declareGeneratedDialect<xegpu::XeGPUDialect>(); |
| |
| registerTransformOps< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" |
| >(); |
| } |
| } // namespace |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" |
| |
| void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { |
| registry.addExtensions<XeGPUTransformDialectExtension>(); |
| } |