blob: 5da1b76e929be49e6d6fa7053aa89bacee886c04 [file] [log] [blame]
//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
using namespace llvm;
using namespace mlir;
using scf::ForallOp;
using scf::ForOp;
using scf::LoopNest;
LogicalResult
mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
SmallVectorImpl<Operation *> *results) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(forallOp);
Location loc = forallOp.getLoc();
SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
SmallVector<Value> steps = forallOp.getStep(rewriter);
LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
SmallVector<Value> ivs = llvm::map_to_vector(
loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
Block *innermostBlock = loopNest.loops.back().getBody();
rewriter.eraseOp(forallOp.getBody()->getTerminator());
rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
innermostBlock->getTerminator()->getIterator(),
ivs);
rewriter.eraseOp(forallOp);
if (results) {
llvm::move(loopNest.loops, std::back_inserter(*results));
}
return success();
}
namespace {
struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
void runOnOperation() override {
Operation *parentOp = getOperation();
IRRewriter rewriter(parentOp->getContext());
parentOp->walk([&](scf::ForallOp forallOp) {
if (failed(scf::forallToForLoop(rewriter, forallOp))) {
return signalPassFailure();
}
});
}
};
} // namespace
std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
return std::make_unique<ForallToForLoop>();
}