// blob: 153b9b170e5d3406f0f2c2bcb820ff0e99964cb2 [file] [log] [blame]
//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_SPARSEASSEMBLER
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_SPARSEGPUCODEGEN
#define GEN_PASS_DEF_STAGESPARSEOPERATIONS
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
//===----------------------------------------------------------------------===//
// Passes implementation.
//===----------------------------------------------------------------------===//
struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
  SparseAssembler() = default;
  SparseAssembler(const SparseAssembler &pass) = default;
  SparseAssembler(bool dO) { directOut = dO; }
  // Greedily applies the sparse assembler rewrite patterns, forwarding the
  // direct-output pass option.
  void runOnOperation() override {
    RewritePatternSet assemblerPatterns(&getContext());
    populateSparseAssembler(assemblerPatterns, directOut);
    (void)applyPatternsGreedily(getOperation(), std::move(assemblerPatterns));
  }
};
struct SparseReinterpretMap
    : public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
  SparseReinterpretMap() = default;
  SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
  SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
    scope = options.scope;
  }
  // Greedily applies the reinterpret-map rewrite patterns for the
  // configured scope.
  void runOnOperation() override {
    RewritePatternSet remapPatterns(&getContext());
    populateSparseReinterpretMap(remapPatterns, scope);
    (void)applyPatternsGreedily(getOperation(), std::move(remapPatterns));
  }
};
struct PreSparsificationRewritePass
    : public impl::PreSparsificationRewriteBase<PreSparsificationRewritePass> {
  PreSparsificationRewritePass() = default;
  PreSparsificationRewritePass(const PreSparsificationRewritePass &pass) =
      default;
  // Greedily applies the pre-sparsification rewriting patterns.
  void runOnOperation() override {
    RewritePatternSet prePatterns(&getContext());
    populatePreSparsificationRewriting(prePatterns);
    (void)applyPatternsGreedily(getOperation(), std::move(prePatterns));
  }
};
struct SparsificationPass
    : public impl::SparsificationPassBase<SparsificationPass> {
  SparsificationPass() = default;
  SparsificationPass(const SparsificationPass &pass) = default;
  SparsificationPass(const SparsificationOptions &options) {
    parallelization = options.parallelizationStrategy;
    sparseEmitStrategy = options.sparseEmitStrategy;
    enableRuntimeLibrary = options.enableRuntimeLibrary;
  }
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    // Reassemble the strategy options from the individual pass flags.
    const SparsificationOptions options(parallelization, sparseEmitStrategy,
                                        enableRuntimeLibrary);
    // Apply the sparsification patterns together with scf.for
    // canonicalization as cleanup.
    RewritePatternSet sparsifyPatterns(context);
    populateSparsificationPatterns(sparsifyPatterns, options);
    scf::ForOp::getCanonicalizationPatterns(sparsifyPatterns, context);
    (void)applyPatternsGreedily(getOperation(), std::move(sparsifyPatterns));
  }
};
struct StageSparseOperationsPass
    : public impl::StageSparseOperationsBase<StageSparseOperationsPass> {
  StageSparseOperationsPass() = default;
  StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default;
  // Greedily applies the staging patterns for sparse operations.
  void runOnOperation() override {
    RewritePatternSet stagingPatterns(&getContext());
    populateStageSparseOperationsPatterns(stagingPatterns);
    (void)applyPatternsGreedily(getOperation(), std::move(stagingPatterns));
  }
};
struct LowerSparseOpsToForeachPass
    : public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
  LowerSparseOpsToForeachPass() = default;
  LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
      default;
  LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
    enableRuntimeLibrary = enableRT;
    enableConvert = convert;
  }
  // Greedily applies the patterns that lower sparse operations to
  // foreach form, honoring the runtime-library and convert options.
  void runOnOperation() override {
    RewritePatternSet foreachPatterns(&getContext());
    populateLowerSparseOpsToForeachPatterns(
        foreachPatterns, enableRuntimeLibrary, enableConvert);
    (void)applyPatternsGreedily(getOperation(), std::move(foreachPatterns));
  }
};
struct LowerForeachToSCFPass
    : public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
  LowerForeachToSCFPass() = default;
  LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
  // Greedily applies the foreach-to-SCF lowering patterns.
  void runOnOperation() override {
    RewritePatternSet loweringPatterns(&getContext());
    populateLowerForeachToSCFPatterns(loweringPatterns);
    (void)applyPatternsGreedily(getOperation(), std::move(loweringPatterns));
  }
};
struct LowerSparseIterationToSCFPass
    : public impl::LowerSparseIterationToSCFBase<
          LowerSparseIterationToSCFPass> {
  LowerSparseIterationToSCFPass() = default;
  LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
      default;
  // Partially converts the sparse iteration ops to SCF constructs.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    SparseIterationTypeConverter converter;
    ConversionTarget target(*context);
    // Only the iteration ops themselves must be rewritten; the dialects
    // that may appear in the generated code remain legal.
    target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
                        IterateOp>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, scf::SCFDialect,
                           sparse_tensor::SparseTensorDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate and run the actual conversion.
    RewritePatternSet conversionPatterns(context);
    populateLowerSparseIterationToSCFPatterns(converter, conversionPatterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(conversionPatterns))))
      signalPassFailure();
  }
};
// Converts sparse tensor types and operations via a partial dialect
// conversion driven by SparseTensorTypeToPtrConverter (presumably mapping
// sparse tensors to runtime-library pointers -- see the converter's
// definition for the exact mapping).
struct SparseTensorConversionPass
    : public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
  SparseTensorConversionPass() = default;
  SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToPtrConverter converter;
    ConversionTarget target(*ctx);
    // Everything in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // All dynamic rules below accept new function, call, return, and various
    // tensor and bufferization operations as legal output of the rewriting
    // provided that all sparse tensor types have been fully rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<tensor::DimOp>([&](tensor::DimOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    // Casts and shape-changing ops are legal only once both their source
    // and result types have been converted.
    target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
      return converter.isLegal(op.getSource().getType()) &&
             converter.isLegal(op.getDest().getType());
    });
    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
        [&](tensor::ExpandShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
        [&](tensor::CollapseShapeOp op) {
          return converter.isLegal(op.getSrc().getType()) &&
                 converter.isLegal(op.getResult().getType());
        });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // rewriting rules, and are therefore marked as legal.
    target.addLegalOp<complex::ConstantOp, complex::NotEqualOp, linalg::FillOp,
                      linalg::YieldOp, tensor::ExtractOp,
                      tensor::FromElementsOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        LLVM::LLVMDialect, memref::MemRefDialect, scf::SCFDialect>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorConversionPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
// Lowers sparse tensor types and operations to buffer form via a partial
// dialect conversion driven by SparseTensorTypeToBufferConverter.
struct SparseTensorCodegenPass
    : public impl::SparseTensorCodegenBase<SparseTensorCodegenPass> {
  SparseTensorCodegenPass() = default;
  SparseTensorCodegenPass(const SparseTensorCodegenPass &pass) = default;
  // createDeallocs: forwarded to the createSparseDeallocs pass option.
  // enableInit: forwarded to the enableBufferInitialization pass option.
  SparseTensorCodegenPass(bool createDeallocs, bool enableInit) {
    createSparseDeallocs = createDeallocs;
    enableBufferInitialization = enableInit;
  }
  void runOnOperation() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    SparseTensorTypeToBufferConverter converter;
    ConversionTarget target(*ctx);
    // Most ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // Sort and push_back are kept legal (handled by a later rewrite).
    target.addLegalOp<SortOp>();
    target.addLegalOp<PushBackOp>();
    // Storage specifier outlives sparse tensor pipeline.
    target.addLegalOp<GetStorageSpecifierOp>();
    target.addLegalOp<SetStorageSpecifierOp>();
    target.addLegalOp<StorageSpecifierInitOp>();
    // Note that tensor::FromElementsOp might be yield after lowering unpack.
    target.addLegalOp<tensor::FromElementsOp>();
    // All dynamic rules below accept new function, call, return, and
    // various tensor and bufferization operations as legal output of the
    // rewriting provided that all sparse tensor types have been fully
    // rewritten.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
        [&](bufferization::AllocTensorOp op) {
          return converter.isLegal(op.getType());
        });
    target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>(
        [&](bufferization::DeallocTensorOp op) {
          return converter.isLegal(op.getTensor().getType());
        });
    // The following operations and dialects may be introduced by the
    // codegen rules, and are therefore marked as legal.
    target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
    target.addLegalDialect<
        arith::ArithDialect, bufferization::BufferizationDialect,
        complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    // Populate with rules and apply rewriting rules.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorCodegenPatterns(
        converter, patterns, createSparseDeallocs, enableBufferInitialization);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
struct SparseBufferRewritePass
    : public impl::SparseBufferRewriteBase<SparseBufferRewritePass> {
  SparseBufferRewritePass() = default;
  SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
  SparseBufferRewritePass(bool enableInit) {
    enableBufferInitialization = enableInit;
  }
  // Greedily applies the sparse buffer rewriting patterns, forwarding the
  // buffer-initialization option.
  void runOnOperation() override {
    RewritePatternSet bufferPatterns(&getContext());
    populateSparseBufferRewriting(bufferPatterns, enableBufferInitialization);
    (void)applyPatternsGreedily(getOperation(), std::move(bufferPatterns));
  }
};
struct SparseVectorizationPass
    : public impl::SparseVectorizationBase<SparseVectorizationPass> {
  SparseVectorizationPass() = default;
  SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
  SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
    vectorLength = vl;
    enableVLAVectorization = vla;
    enableSIMDIndex32 = sidx32;
  }
  void runOnOperation() override {
    // A zero vector length is rejected outright.
    if (vectorLength == 0)
      return signalPassFailure();
    // Apply the vectorization patterns followed by vector-to-vector
    // canonicalization as cleanup.
    RewritePatternSet vectorPatterns(&getContext());
    populateSparseVectorizationPatterns(vectorPatterns, vectorLength,
                                        enableVLAVectorization,
                                        enableSIMDIndex32);
    vector::populateVectorToVectorCanonicalizationPatterns(vectorPatterns);
    (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns));
  }
};
struct SparseGPUCodegenPass
    : public impl::SparseGPUCodegenBase<SparseGPUCodegenPass> {
  SparseGPUCodegenPass() = default;
  SparseGPUCodegenPass(const SparseGPUCodegenPass &pass) = default;
  SparseGPUCodegenPass(unsigned nT, bool enableRT) {
    numThreads = nT;
    enableRuntimeLibrary = enableRT;
  }
  void runOnOperation() override {
    RewritePatternSet gpuPatterns(&getContext());
    // A thread count of zero selects the library-generation patterns;
    // any other value selects codegen with that many threads.
    if (numThreads != 0)
      populateSparseGPUCodegenPatterns(gpuPatterns, numThreads);
    else
      populateSparseGPULibgenPatterns(gpuPatterns, enableRuntimeLibrary);
    (void)applyPatternsGreedily(getOperation(), std::move(gpuPatterns));
  }
};
// Converts sparse storage specifier types and operations to the LLVM
// dialect via a partial dialect conversion driven by
// StorageSpecifierToLLVMTypeConverter.
struct StorageSpecifierToLLVMPass
    : public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
  StorageSpecifierToLLVMPass() = default;
  void runOnOperation() override {
    auto *ctx = &getContext();
    ConversionTarget target(*ctx);
    RewritePatternSet patterns(ctx);
    StorageSpecifierToLLVMTypeConverter converter;
    // All ops in the sparse dialect must go!
    target.addIllegalDialect<SparseTensorDialect>();
    // Functions, calls, and returns are legal only once their signatures
    // and operand types no longer mention storage specifier types.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
      return converter.isSignatureLegal(op.getCalleeType());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return converter.isLegal(op.getOperandTypes());
    });
    target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
    // Convert function signatures, calls, branches, returns, and SCF
    // structure alongside the storage specifier patterns themselves.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateStorageSpecifierToLLVMPatterns(converter, patterns);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass creation methods.
//===----------------------------------------------------------------------===//
/// Creates a SparseAssembler pass with default option values.
std::unique_ptr<Pass> mlir::createSparseAssembler() {
  return std::make_unique<SparseAssembler>();
}
/// Creates a SparseReinterpretMap pass with default option values.
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
  return std::make_unique<SparseReinterpretMap>();
}
/// Creates a SparseReinterpretMap pass restricted to the given scope.
std::unique_ptr<Pass>
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
  SparseReinterpretMapOptions options;
  options.scope = scope;
  return std::make_unique<SparseReinterpretMap>(options);
}
/// Creates a PreSparsificationRewrite pass with default option values.
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
  return std::make_unique<PreSparsificationRewritePass>();
}
/// Creates a Sparsification pass with default option values.
std::unique_ptr<Pass> mlir::createSparsificationPass() {
  return std::make_unique<SparsificationPass>();
}
/// Creates a Sparsification pass from the given strategy options.
std::unique_ptr<Pass>
mlir::createSparsificationPass(const SparsificationOptions &options) {
  return std::make_unique<SparsificationPass>(options);
}
/// Creates a StageSparseOperations pass.
std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
  return std::make_unique<StageSparseOperationsPass>();
}
/// Creates a LowerSparseOpsToForeach pass with default option values.
std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
  return std::make_unique<LowerSparseOpsToForeachPass>();
}
/// Creates a LowerSparseOpsToForeach pass with explicit runtime-library
/// and convert options.
std::unique_ptr<Pass>
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
  return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}
/// Creates a LowerForeachToSCF pass.
std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
  return std::make_unique<LowerForeachToSCFPass>();
}
/// Creates a LowerSparseIterationToSCF pass.
std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
  return std::make_unique<LowerSparseIterationToSCFPass>();
}
/// Creates a SparseTensorConversion pass.
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
  return std::make_unique<SparseTensorConversionPass>();
}
/// Creates a SparseTensorCodegen pass with default option values.
std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
  return std::make_unique<SparseTensorCodegenPass>();
}
/// Creates a SparseTensorCodegen pass with explicit dealloc-creation and
/// buffer-initialization options.
std::unique_ptr<Pass>
mlir::createSparseTensorCodegenPass(bool createSparseDeallocs,
                                    bool enableBufferInitialization) {
  return std::make_unique<SparseTensorCodegenPass>(createSparseDeallocs,
                                                   enableBufferInitialization);
}
/// Creates a SparseBufferRewrite pass with default option values.
std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
  return std::make_unique<SparseBufferRewritePass>();
}
/// Creates a SparseBufferRewrite pass with an explicit
/// buffer-initialization option.
std::unique_ptr<Pass>
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
  return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
}
/// Creates a SparseVectorization pass with default option values.
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
  return std::make_unique<SparseVectorizationPass>();
}
/// Creates a SparseVectorization pass with explicit vector length, VLA,
/// and 32-bit SIMD index options. Note that the pass itself fails when
/// vectorLength is zero.
std::unique_ptr<Pass>
mlir::createSparseVectorizationPass(unsigned vectorLength,
                                    bool enableVLAVectorization,
                                    bool enableSIMDIndex32) {
  return std::make_unique<SparseVectorizationPass>(
      vectorLength, enableVLAVectorization, enableSIMDIndex32);
}
/// Creates a SparseGPUCodegen pass with default option values.
std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass() {
  return std::make_unique<SparseGPUCodegenPass>();
}
/// Creates a SparseGPUCodegen pass; numThreads == 0 selects the
/// library-generation path inside the pass.
std::unique_ptr<Pass> mlir::createSparseGPUCodegenPass(unsigned numThreads,
                                                       bool enableRT) {
  return std::make_unique<SparseGPUCodegenPass>(numThreads, enableRT);
}
/// Creates a StorageSpecifierToLLVM pass.
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
  return std::make_unique<StorageSpecifierToLLVMPass>();
}