blob: 2089b6579ccb587fb97627bcce5b80a5e847873b [file] [log] [blame]
//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a partial lowering of Toy operations to a combination of
// affine loops and standard operations. This lowering expects that all calls
// have been inlined, and all shapes have been resolved.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return MemRefType::get(type.getShape(), type.getElementType());
}
/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter) {
auto alloc = rewriter.create<AllocOp>(loc, type);
// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
alloc->moveBefore(&parentBlock->front());
// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc->moveBefore(&parentBlock->back());
return alloc;
}
/// This defines the function type used to process an iteration of a lowered
/// loop. It takes as input an OpBuilder, an range of memRefOperands
/// corresponding to the operands of the input operation, and the range of loop
/// induction variables for the iteration. It returns a value to store at the
/// current index of the iteration.
using LoopIterationFn = function_ref<Value(
OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs)>;
static void lowerOpToLoops(Operation *op, ArrayRef<Value> operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// Create a nest of affine loops, with one loop per dimension of the shape.
// The buildAffineLoopNest function takes a callback that is used to construct
// the body of the innermost loop given a builder, a location and a range of
// loop induction variables.
SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
buildAffineLoopNest(
rewriter, loc, lowerBounds, tensorType.getShape(), steps,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
// Call the processing function with the rewriter, the memref operands,
// and the loop induction variables. This function will return the value
// to store at the current index.
Value valueToStore = processIteration(nestedBuilder, operands, ivs);
nestedBuilder.create<AffineStoreOp>(loc, valueToStore, alloc, ivs);
});
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
}
namespace {
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Binary operations
//===----------------------------------------------------------------------===//
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext *ctx)
: ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
lowerOpToLoops(
op, operands, rewriter,
[loc](OpBuilder &builder, ValueRange memRefOperands,
ValueRange loopIvs) {
// Generate an adaptor for the remapped operands of the BinaryOp. This
// allows for using the nice named accessors that are generated by the
// ODS.
typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the inner
// loop.
auto loadedLhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.lhs(), loopIvs);
auto loadedRhs =
builder.create<AffineLoadOp>(loc, binaryAdaptor.rhs(), loopIvs);
// Create the binary operation performed on the loaded values.
return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs);
});
return success();
}
};
using AddOpLowering = BinaryOpLowering<toy::AddOp, AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, MulFOp>;
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Constant operations
//===----------------------------------------------------------------------===//
struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
LogicalResult matchAndRewrite(toy::ConstantOp op,
PatternRewriter &rewriter) const final {
DenseElementsAttr constantValue = op.value();
Location loc = op.getLoc();
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<TensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// We will be generating constant indices up-to the largest dimension.
// Create these constants up-front to avoid large amounts of redundant
// operations.
auto valueShape = memRefType.getShape();
SmallVector<Value, 8> constantIndices;
if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, i));
} else {
// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
}
// The constant operation represents a multi-dimensional constant, so we
// will need to generate a store for each of the elements. The following
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.
SmallVector<Value, 2> indices;
auto valueIt = constantValue.getValues<FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
if (dimension == valueShape.size()) {
rewriter.create<AffineStoreOp>(
loc, rewriter.create<ConstantOp>(loc, *valueIt++), alloc,
llvm::makeArrayRef(indices));
return;
}
// Otherwise, iterate over the current dimension and add the indices to
// the list.
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
}
};
// Start the element storing recursion from the first dimension.
storeElements(/*dimension=*/0);
// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return success();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Return operations
//===----------------------------------------------------------------------===//
struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
LogicalResult matchAndRewrite(toy::ReturnOp op,
PatternRewriter &rewriter) const final {
// During this lowering, we expect that all function calls have been
// inlined.
if (op.hasOperand())
return failure();
// We lower "toy.return" directly to "std.return".
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return success();
}
};
//===----------------------------------------------------------------------===//
// ToyToAffine RewritePatterns: Transpose operations
//===----------------------------------------------------------------------===//
struct TransposeOpLowering : public ConversionPattern {
TransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
lowerOpToLoops(op, operands, rewriter,
[loc](OpBuilder &builder, ValueRange memRefOperands,
ValueRange loopIvs) {
// Generate an adaptor for the remapped operands of the
// TransposeOp. This allows for using the nice named
// accessors that are generated by the ODS.
toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
Value input = transposeAdaptor.input();
// Transpose the elements by generating a load from the
// reverse indices.
SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
return builder.create<AffineLoadOp>(loc, input,
reverseIvs);
});
return success();
}
};
} // end anonymous namespace.
//===----------------------------------------------------------------------===//
// ToyToAffineLoweringPass
//===----------------------------------------------------------------------===//
/// This is a partial lowering to affine loops of the toy operations that are
/// computationally intensive (like matmul for example...) while keeping the
/// rest of the code in the Toy dialect.
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, StandardOpsDialect>();
}
void runOnFunction() final;
};
} // end anonymous namespace.
void ToyToAffineLoweringPass::runOnFunction() {
auto function = getFunction();
// We only lower the main function as we expect that all other functions have
// been inlined.
if (function.getName() != "main")
return;
// Verify that the given main has no inputs and results.
if (function.getNumArguments() || function.getType().getNumResults()) {
function.emitError("expected 'main' to have 0 inputs and 0 results");
return signalPassFailure();
}
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering. In our case, we are lowering to a combination of the
// `Affine` and `Standard` dialects.
target.addLegalDialect<AffineDialect, StandardOpsDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted. Given that we actually want
// a partial lowering, we explicitly mark the Toy operations that don't want
// to lower, `toy.print`, as `legal`.
target.addIllegalDialect<toy::ToyDialect>();
target.addLegalOp<toy::PrintOp>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the Toy operations.
OwningRewritePatternList patterns;
patterns.insert<AddOpLowering, ConstantOpLowering, MulOpLowering,
ReturnOpLowering, TransposeOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
}
/// Create a pass for lowering operations in the `Affine` and `Std` dialects,
/// for a subset of the Toy IR (e.g. matmul).
std::unique_ptr<Pass> mlir::toy::createLowerToAffinePass() {
return std::make_unique<ToyToAffineLoweringPass>();
}