// blob: 49a882f31dccd30c37454797d785704e34bf0052 [file] [log] [blame]
//===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
using namespace mlir;
using namespace mlir::func;
namespace mlir {
namespace sparse_tensor {
/// Report whether any type in `types` is a sparse tensor type.
///
/// Each type is queried with `getSparseTensorEncoding`; the scan stops at the
/// first type for which that query yields an encoding. An empty range yields
/// `false`.
static bool containsSparseTensor(TypeRange types) {
  for (Type candidate : types) {
    // Types without a sparse encoding are of no interest; keep scanning.
    if (!getSparseTensorEncoding(candidate))
      continue;
    return true;
  }
  return false;
}
/// A pass that bufferizes only dense tensor ops and ignores all sparse tensor
/// ops. No buffer copies are inserted. All tensor OpOperands must be
/// inplacable.
///
/// The pass runs on a `ModuleOp` and delegates the actual rewriting to
/// `bufferization::bufferizeOp`, restricted by an op filter that rejects:
///   * any op with a sparse tensor type among its results or operands, and
///   * any `func::FuncOp` whose function type has a sparse tensor type among
///     its inputs or results.
class BufferizeDenseOpsPass
    : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
public:
  /// Construct the pass with a private copy of `options`.
  ///
  /// Marked `explicit` so that bufferization options are never silently
  /// converted into a pass object.
  explicit BufferizeDenseOpsPass(
      const bufferization::OneShotBufferizationOptions &options)
      : options(options) {}

  /// Bufferize the dense part of the module; signal pass failure if
  /// `bufferization::bufferizeOp` reports failure.
  void runOnOperation() override {
    // Disallow all sparse tensor ops, so that only dense tensor ops are
    // bufferized.
    bufferization::OpFilter opFilter;
    // The predicate only calls a free function, so it needs no captures.
    opFilter.allowOperation([](Operation *op) {
      if (containsSparseTensor(TypeRange(op->getResults())) ||
          containsSparseTensor(TypeRange(op->getOperands())))
        return false;
      // A function may mention sparse tensors only in its signature, which is
      // not visible through its operands/results; check the function type too.
      if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
        FunctionType funcType = funcOp.getFunctionType();
        if (containsSparseTensor(funcType.getInputs()) ||
            containsSparseTensor(funcType.getResults()))
          return false;
      }
      return true;
    });

    if (failed(bufferization::bufferizeOp(getOperation(), options,
                                          /*copyBeforeWrite=*/false,
                                          &opFilter)))
      signalPassFailure();
  }

private:
  /// Options forwarded to `bufferization::bufferizeOp` on every run.
  bufferization::OneShotBufferizationOptions options;
};
} // namespace sparse_tensor
} // namespace mlir
/// Create a `BufferizeDenseOpsPass` configured with `options` and return it
/// as an owning pointer to the generic `Pass` interface.
std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
    const bufferization::OneShotBufferizationOptions &options) {
  using DensePass = mlir::sparse_tensor::BufferizeDenseOpsPass;
  return std::make_unique<DensePass>(options);
}