blob: 82e60800f21be5532b43cd29d613142d65332285 [file] [log] [blame]
//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization
class FrozenRewritePatternSet;
namespace linalg {
struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;
/// Default function to control reshape folding. Skips folding unit dimension
/// reshapes.
bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;
/// [DEPRECATED] Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<RewritePatternSet> &patterns,
ArrayRef<int64_t> tileSizes);
/// Populates patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops; it assumes high-D convolution ops
/// were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
/// Function type which is used to control when to stop fusion. It is expected
/// that OpOperand is not modified in the callback. The OpOperand is not marked
/// as const to allow callers to use non-const methods.
using ControlElementwiseOpsFusionFn =
std::function<bool(const OpResult &producer, OpOperand &consumer)>;
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation. The patterns are applied only when the tensor reshape involved is
/// collapsing (introducing) unit-extent dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
LinalgOp linalgOp, ValueRange inputs,
ValueRange outputs);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
/// Pattern to convert TiledLoopOp to SCF loops.
void populateTiledLoopToSCFPattern(RewritePatternSet &patterns);
/// Knobs that steer the fusion of elementwise operations. Both setters return
/// `*this` so that calls can be chained.
struct LinalgElementwiseFusionOptions {
  /// Predicate controlling the fusion of reshapes with elementwise operations.
  /// The default, `skipUnitDimReshape`, disables it for unit-dimension
  /// reshapes.
  ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;

  /// Replaces the reshape-folding predicate with `fun`.
  LinalgElementwiseFusionOptions &
  setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
    controlFoldingReshapesFn = std::move(fun);
    return *this;
  }

  /// Hook that lets the caller stop fusion. After a producer has been found to
  /// be structurally fusable with its consumer, this callback may still veto
  /// the fusion on non-structural grounds; it is the place to plug in a cost
  /// model. The default accepts every candidate.
  ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn =
      [](const OpResult &, OpOperand &) { return true; };

  /// Replaces the fusion-control callback with `fun`.
  LinalgElementwiseFusionOptions &
  setControlElementwiseOpsFusionFn(ControlElementwiseOpsFusionFn fun) {
    controlElementwiseOpsFusionFn = std::move(fun);
    return *this;
  }
};
/// Patterns for fusing linalg operation on tensors.
void populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns,
LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
/// Patterns to push reshape op towards the end of the graph in order to expose
/// more fusion opportunities.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`
/// and permutes the loop nest according to `interchangeVector`.
/// The permutation is expressed as a list of integers that specify
/// the new ordering of the loop nest. The length of `interchangeVector`
/// must be equal to the length of `tileSizes`.
/// An empty vector is interpreted as the identity permutation and the
/// transformation returns early.
///
/// Returns a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
struct TiledLinalgOp {
LinalgOp op;
SmallVector<Operation *, 8> loops;
SmallVector<Value, 4> tensorResults;
};
FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
const LinalgTilingOptions &options);
/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
/// proceeds as follows:
/// - Find outer parallel loops in these ops that can be fused.
/// - Tile fusable outer parallel loops of the last operation in the sequence.
/// - Fuse the remaining operations with the tiled operation
///
/// For example, consider the sequence of matmul below
///
/// linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
/// outs(%arg2 : memref<256x32xf32>)
/// linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
/// outs(%arg4 : memref<256x32xf32>)
///
/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
/// matmuls row-wise. For example, the fused computation for the above is shown
/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
/// along the rows of the matrix. The entire rows of the first matmul operation
/// need to be computed before they can be used for the second matmul. The
/// second matmul is further tiled (similar to normal tiling).
///
/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
/// %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
/// : memref<256x32xf32> to memref<16x32xf32, #map0>
/// %3 = subview %arg1[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// %4 = subview %arg3[0, 0] [32, 32] [1, 1]
/// : memref<32x32xf32> to memref<32x32xf32, #map1>
/// linalg.matmul
/// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
/// outs(%0 : memref<16x32xf32, #map0>)
/// linalg.matmul
/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
/// outs(%1 : memref<16x8xf32, #map0>)
/// }
///
/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
/// size of the former should be same as size of the latter). Based on how
/// tile+fuse is implemented, the fused loops are generated based on the last
/// operation in the sequence. For example, the tile sizes for the fused loops
/// is obtained from `tilingOptions.back()`. The following tiling options are
/// handled differently in tile+fuse (compared to tile only)
/// - Interchange of the tiling loops is not supported right now.
/// - Only the fused loops are distributed.
struct TiledAndFusedLinalgOps {
/// Operation obtained by tiling the last operation in the sequence of `ops`
/// passed to `tileAndFuseLinalgOps`.
LinalgOp op;
/// The dimensions of the loops that are fused.
std::set<unsigned> fusedLoopDims;
/// The generated fused operations (created within the fused loops).
SmallVector<LinalgOp, 1> fusedProducers;
/// The fused loops that were generated.
SmallVector<Operation *, 4> fusedLoops;
};
FailureOr<TiledAndFusedLinalgOps>
tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph,
const LinalgTilingOptions &tilingOptions);
/// Interchanges the `iterator_types` and `indexing_maps` dimensions and adapts
/// the index accesses of `op`. This is an in-place transformation controlled by
/// `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Creates a GenericOp from the given named operation `namedOp`. Assumes
/// `namedOp` is not a GenericOp and has a region builder.
GenericOp generalizeNamedOp(PatternRewriter &rewriter, LinalgOp namedOp);
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewSize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
/// dimension. If that is not possible, it contains the dynamic size of the
/// subview. The callback should return the buffer to use.
using AllocBufferCallbackFn = std::function<Optional<Value>(
OpBuilder &b, memref::SubViewOp subView,
ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;
/// Callback function type used to deallocate the buffers used to hold the
/// promoted subview.
using DeallocBufferCallbackFn =
std::function<LogicalResult(OpBuilder &b, Value buffer)>;
/// Callback function type used to insert copy from original subview to subview
/// of the promoted region for the read operands/subview of promoted region to
/// original subview for the results. The copy has to happen from `src` to
/// `dst`.
using CopyCallbackFn =
std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
struct LinalgPromotionOptions {
/// Indices of subViews to promote. If `None`, try to promote all operands.
Optional<DenseSet<unsigned>> operandsToPromote = None;
LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
operandsToPromote = DenseSet<unsigned>();
operandsToPromote->insert(operands.begin(), operands.end());
return *this;
}
/// If ith element of `useFullTiles` is true the full view should be used for
/// the promoted buffer of the ith operand in `operandsToPromote`. Otherwise
/// the partial view will be used.
/// The decision is defaulted to `useFullTileBuffersDefault` when
/// `useFullTileBuffers` is None and for operands missing from
/// `useFullTileBuffers`.
Optional<llvm::SmallBitVector> useFullTileBuffers = None;
LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
unsigned size = useFullTiles.size();
llvm::SmallBitVector tmp(size, false);
for (unsigned i = 0; i < size; ++i)
tmp[i] = useFullTiles[i];
useFullTileBuffers = tmp;
return *this;
}
/// If true all operands unspecified by `useFullTileBuffers` will use the full
/// view, otherwise the partial view.
bool useFullTileBuffersDefault = false;
LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
useFullTileBuffersDefault = use;
return *this;
}
/// Allow the use of dynamically-sized buffers.
bool dynamicBuffers = false;
LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
dynamicBuffers = dynamic;
return *this;
}
/// Alignment of promoted buffer. If `None` do not specify alignment.
Optional<unsigned> alignment = None;
LinalgPromotionOptions &setAlignment(unsigned align) {
alignment = align;
return *this;
}
/// Use alloca with the default allocation scheme.
bool useAlloca = false;
LinalgPromotionOptions &setUseAlloca(bool use) {
useAlloca = use;
return *this;
}
/// Callback function to do the allocation of the promoted buffer. If None,
/// then the default allocation scheme of allocating a memref<?xi8> buffer
/// followed by a view operation is used.
Optional<AllocBufferCallbackFn> allocationFn = None;
Optional<DeallocBufferCallbackFn> deallocationFn = None;
LinalgPromotionOptions &
setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
DeallocBufferCallbackFn const &deallocFn) {
allocationFn = allocFn;
deallocationFn = deallocFn;
return *this;
}
/// Callback function to do the copy of data to and from the promoted
/// subview. If None then a linalg.copy is used.
Optional<CopyCallbackFn> copyInFn = None;
Optional<CopyCallbackFn> copyOutFn = None;
LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
CopyCallbackFn const &copyOut) {
copyInFn = copyIn;
copyOutFn = copyOut;
return *this;
}
};
/// Creates a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
/// of subview operation as `partialLocalView`.
struct PromotionInfo {
Value fullLocalView;
Value partialLocalView;
};
FailureOr<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
AllocBufferCallbackFn allocationFn,
DataLayout &layout);
/// Promotes the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
/// 1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
/// 2. Take a full view on the buffer.
/// 3. Take a partial slice of the full view in step 2. and copy into it.
/// Infers statically sized buffers from subViews unless `dynamicBuffers` is
/// true.
///
/// Returns the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
LinalgPromotionOptions options);
/// Emit a suitable vector form for a Linalg op with fully static shape.
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &newResults);
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
LinalgOp linalgOp);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Emits a `generic` operation with the `indexing_maps` and `iterator_types`
/// permuted according to `interchangeVector`.
LogicalResult
interchangeGenericOpPrecondition(GenericOp genericOp,
ArrayRef<unsigned> interchangeVector);
/// Generalize named operations to generic operations.
LogicalResult generalizeNamedOpPrecondition(Operation *op);
/// Promote std.subviews feeding linalg operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
static const StringLiteral kLinalgTransformMarker;
};
/// Helper class to control application of linalg transformation patterns.
/// Control comes in 2 forms:
///   1. attribute matching and setting behavior using the attribute named
///      `kLinalgTransformMarker`. This can be used to build a state machine
///      using attributes and incrementally applying patterns to advance
///      states.
///   2. filter function, which is a simple lambda on the Operation* that
///      returns a LogicalResult.
struct LinalgTransformationFilter {
  using FilterFunction = std::function<LogicalResult(Operation *)>;

  /// Builds a filter from the set of markers to match (`matchDisjunction`) and
  /// an optional `replacement` marker. Defined out of line.
  explicit LinalgTransformationFilter(
      ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  /// Same as above, additionally taking a filter function `f`. Defined out of
  /// line.
  explicit LinalgTransformationFilter(
      FilterFunction f, ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;

  // The three members below are defined out of line.
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                         Operation *op) const;
  bool hasReplacementFilter(Operation *op) const;

  /// Appends `f` to the list of filter functions. An empty (null) function is
  /// ignored. Returns `*this` for chaining.
  LinalgTransformationFilter &addFilter(FilterFunction f) {
    if (f)
      // `f` is a by-value parameter: move it rather than copying the
      // std::function a second time.
      filters.push_back(std::move(f));
    return *this;
  }

  /// Appends a filter function that succeeds iff the operation is one of
  /// `OpTypes`.
  template <typename... OpTypes>
  LinalgTransformationFilter &addOpFilter() {
    return addFilter(
        [](Operation *op) { return success(isa<OpTypes...>(op)); });
  }

  /// Sets `matchByDefault` to true. Returns `*this` for chaining.
  LinalgTransformationFilter &setMatchByDefault() {
    matchByDefault = true;
    return *this;
  }

private:
  SmallVector<FilterFunction> filters;
  SmallVector<StringAttr> matchDisjunction;
  Optional<StringAttr> replacement;
  /// When set to true, if the attribute is not set, it will be treated as
  /// a match. Default is false. The in-class initializer guarantees the
  /// documented default even for a constructor that does not mention this
  /// member; a constructor's own mem-initializer still takes precedence.
  bool matchByDefault = false;
};
using TileSizeComputationFunction =
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
/// Callback returning the padding value to use for a given OpOperand or failure
/// for no padding. This should be a function of both the operation and the
/// operand type.
using PaddingValueComputationFunction =
std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
/// Callback returning true if the pad tensor operation defining the given
/// OpOperand shall be marked as nofold to enable packing.
using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
/// Callback returning the number of loops to hoist the pad tensor operation
/// defining the given OpOperand.
using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
/// Callbacks that configure padding. Each callback starts out unset
/// (`nullptr`), and each setter returns `*this` so that calls can be chained.
struct LinalgPaddingOptions {
  /// Yields the padding value for a given OpOperand, or failure to request no
  /// padding. Padding operations are introduced if
  /// `paddingValueComputationFunction` is set and does not return failure.
  /// Padding all operands guarantees the operation is statically shaped and
  /// thus can be vectorized.
  PaddingValueComputationFunction paddingValueComputationFunction = nullptr;

  LinalgPaddingOptions &
  setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
    paddingValueComputationFunction = std::move(fun);
    return *this;
  }

  /// Tells whether the pad tensor operation defining the given OpOperand shall
  /// be marked as nofold to enable packing. A padding operation is only
  /// marked nofold if `paddingNoFoldComputationFunction` is set and returns
  /// true. Otherwise, the nofold attribute is set to false.
  PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;

  LinalgPaddingOptions &
  setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
    paddingNoFoldComputationFunction = std::move(fun);
    return *this;
  }

  /// Yields the number of loops to hoist the pad tensor operation defining
  /// the given OpOperand.
  PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;

  LinalgPaddingOptions &
  setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) {
    paddingHoistComputationFunction = std::move(fun);
    return *this;
  }
};
struct LinalgTilingAndFusionOptions {
/// Tile sizes used to tile the root operation.
SmallVector<int64_t> tileSizes;
/// Tile interchange used to permute the tile loops.
SmallVector<int64_t> tileInterchange;
};
struct LinalgTilingOptions {
/// Computation function that returns the tile sizes for each operation.
/// Delayed construction of constant tile sizes should occur to interoperate
/// with folding.
TileSizeComputationFunction tileSizeComputationFunction = nullptr;
LinalgTilingOptions &
setTileSizeComputationFunction(TileSizeComputationFunction fun) {
tileSizeComputationFunction = std::move(fun);
return *this;
}
/// Set the `tileSizeComputationFunction` to return the values `ts`. The
/// values must not fold away when tiling. Otherwise, use a more robust
/// `tileSizeComputationFunction`.
LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
return *this;
}
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
/// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
/// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
LinalgTilingOptions &scalarizeDynamicDims();
/// The interchange vector to reorder the tiled loops.
SmallVector<unsigned, 4> interchangeVector = {};
LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
interchangeVector.assign(interchange.begin(), interchange.end());
return *this;
}
/// The type of tile loops to generate.
LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;
LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
loopType = lt;
return *this;
}
/// When specified, specifies distribution of generated tile loops to
/// processors.
Optional<LinalgLoopDistributionOptions> distribution = None;
LinalgTilingOptions &
setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
distribution = std::move(distributionOptions);
return *this;
}
/// Specification markers of how to distribute the `linalg.tiled_loop`.
SmallVector<StringRef, 2> distributionTypes = {};
LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
distributionTypes.assign(types.begin(), types.end());
return *this;
}
/// Peel the specified loops.
SmallVector<int64_t> peeledLoops;
LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
peeledLoops.clear();
peeledLoops.append(loops.begin(), loops.end());
return *this;
}
};
/// Canonicalization patterns relevant to apply after tiling patterns. These are
/// applied automatically by the tiling pass but need to be applied manually
/// when tiling is called programmatically.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// Base pattern that applies the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
/// 1. if the tiling specification is invalid and tiling fails to occur.
/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
/// and some operand shape cannot be bounded statically.
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgBaseTilingPattern(
MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific Linalg op.
LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Shared rewrite entry point, defined out of line. The derived patterns in
/// this file call it from their `matchAndRewrite` and, when it does not fail,
/// read `result.tensorResults` to decide whether to erase or replace `op`.
LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
TiledLinalgOp &result) const;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
/// Options to control tiling.
LinalgTilingOptions options;
};
/// Tiling pattern parameterized on an op type. The first constructor takes
/// the operation name from `OpTy`; the second accepts an explicit name.
template <typename OpTy>
struct LinalgTilingPattern : public LinalgBaseTilingPattern {
  /// SFINAE: This constructor can only trigger for concrete ops that have a
  /// static `getOperationName` method.
  template <typename ConcreteOpTy = OpTy>
  LinalgTilingPattern(
      MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(ConcreteOpTy::getOperationName(), context,
                                options, filter, benefit) {}

  /// This constructor is available to anyone.
  LinalgTilingPattern(
      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}

  /// Delegates to `matchAndRewriteBase`. If that fails, so does this pattern.
  /// Otherwise `op` is erased when the tiled op has no tensor results, and
  /// replaced by those results when it has some.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TiledLinalgOp tiled;
    if (failed(matchAndRewriteBase(op, rewriter, tiled)))
      return failure();

    // No tensor results to substitute for `op`: just drop it.
    if (tiled.tensorResults.empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    rewriter.replaceOp(op, tiled.tensorResults);
    return success();
  }
};
/// Non-templated tiling pattern that can match either any LinalgOp or an op
/// given by name.
struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
  /// Entry point to match any LinalgOp OpInterface.
  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
  LinalgGenericTilingPattern(
      MLIRContext *context, LinalgTransformationFilter filter,
      LinalgTilingOptions options = LinalgTilingOptions(),
      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(context, options, filter, benefit) {}

  /// Entry point to match a specific Linalg op.
  LinalgGenericTilingPattern(
      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}

  /// Delegates to `matchAndRewriteBase`. If that fails, so does this pattern.
  /// Otherwise `op` is erased when the tiled op has no tensor results, and
  /// replaced by those results when it has some.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    TiledLinalgOp tiled;
    if (failed(matchAndRewriteBase(op, rewriter, tiled)))
      return failure();

    // No tensor results to substitute for `op`: just drop it.
    if (tiled.tensorResults.empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    rewriter.replaceOp(op, tiled.tensorResults);
    return success();
  }
};
///
/// Linalg padding pattern.
///
/// Apply the `padding` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `padding` for more details.
struct LinalgPaddingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgPaddingPattern(
MLIRContext *context,
LinalgPaddingOptions options = LinalgPaddingOptions(),
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
// Entry point to match a specific LinalgOp.
LinalgPaddingPattern(
StringRef opName, MLIRContext *context,
LinalgPaddingOptions options = LinalgPaddingOptions(),
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
/// Options to control padding and hoisting.
LinalgPaddingOptions options;
};
struct LinalgFusionOptions {
/// List of operands indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
indicesToFuse.insert(operands.begin(), operands.end());
return *this;
}
};
/// Base rewrite pattern for tile-and-fuse, matching the op named `opName`.
/// The constructor and `matchAndRewrite` are defined out of line.
///
/// NOTE: `dependenceGraph` is held as a reference member, so the graph passed
/// to the constructor must outlive the pattern.
struct LinalgBaseTileAndFusePattern : public RewritePattern {
LinalgBaseTileAndFusePattern(
StringRef opName, MLIRContext *context,
const LinalgDependenceGraph &dependenceGraph,
LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
LinalgTransformationFilter originalOpMarker =
LinalgTransformationFilter(),
PatternBenefit benefit = 1);
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// Dependence graph needed for fusion.
const LinalgDependenceGraph &dependenceGraph;
/// Options to control tiling.
LinalgTilingOptions tilingOptions;
/// Options to control fusion.
LinalgFusionOptions fusionOptions;
/// Marker to control application of the pattern.
LinalgTransformationFilter filter;
/// Marker set on the fused op after tile and fuse.
LinalgTransformationFilter fusedOpMarker;
/// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
/// to build the dependence graph change then the dependenceGraph needs to be
/// recomputed right now. To not invalidate the dependenceGraph as
/// transformation happens, the original producer can be tagged with a filter
/// that can be later used to delete the original operations.
LinalgTransformationFilter originalOpMarker;
};
/// Tile-and-fuse pattern for the op type `OpTy`. Its only member is a
/// constructor that passes `OpTy::getOperationName()` as the op name, followed
/// by all of its own arguments, to `LinalgBaseTileAndFusePattern`.
template <typename OpTy>
struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
  LinalgTileAndFusePattern(
      MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
      LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
      LinalgTransformationFilter originalOpMarker =
          LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBaseTileAndFusePattern(OpTy::getOperationName(), context,
                                     dependenceGraph, tilingOptions,
                                     fusionOptions, filter, fusedOpMarker,
                                     originalOpMarker, benefit) {}
};
///
/// Linalg tile and fuse tensor ops pattern.
///
/// Apply tiling and fusion as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tileConsumerAndFuseProducers` for more details.
struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
/// Entry point to match any LinalgOp.
/// `options` carries the tile sizes and interchange used to tile the root
/// operation; `filter` and `benefit` are optional (default-constructed filter,
/// benefit 1).
LinalgTileAndFuseTensorOpsPattern(
MLIRContext *context, LinalgTilingAndFusionOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Entry point to match a specific LinalgOp, identified by `opName`.
/// The remaining parameters have the same meaning as in the constructor
/// above.
LinalgTileAndFuseTensorOpsPattern(
StringRef opName, MLIRContext *context,
LinalgTilingAndFusionOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
/// Tile sizes and interchange used to tile the root operation.
LinalgTilingAndFusionOptions options;
};
///
/// Linalg generic interchange pattern.
///
/// Apply the `interchange` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `interchange` for more details.
struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
/// Also expose the constructors inherited from `OpRewritePattern<GenericOp>`.
using OpRewritePattern<GenericOp>::OpRewritePattern;
/// Build the pattern from the `interchangeVector` used to reorder the
/// iterators and indexing_maps dims. `filter` and `benefit` are optional
/// (default-constructed filter, benefit 1).
GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
/// The interchange vector to reorder the iterators and indexing_maps dims.
SmallVector<unsigned, 8> interchangeVector;
};
///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
struct LinalgGeneralizationPattern : public RewritePattern {
/// Entry point to match any LinalgOp OpInterface.
/// `filter` and `benefit` are optional (default-constructed filter,
/// benefit 1).
LinalgGeneralizationPattern(
MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Entry point to match a specific Linalg op, identified by `opName`.
LinalgGeneralizationPattern(
StringRef opName, MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};
///
/// Linalg promotion patterns.
///
/// Apply the `promoteSubViews` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `promoteSubViews` for more details.
struct LinalgBasePromotionPattern : public RewritePattern {
/// Entry point to match any LinalgOp OpInterface.
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
/// `options` and `benefit` are optional (default-constructed options,
/// benefit 1).
LinalgBasePromotionPattern(
MLIRContext *context, LinalgTransformationFilter filter,
LinalgPromotionOptions options = LinalgPromotionOptions(),
PatternBenefit benefit = 1);
/// Entry point to match a specific Linalg op, identified by `opName`.
/// Here `options` is mandatory while `filter` and `benefit` are optional;
/// note that `options` and `filter` appear in the opposite order compared to
/// the constructor above.
LinalgBasePromotionPattern(
StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
/// Promotion options.
LinalgPromotionOptions options;
};
/// Typed convenience wrapper over `LinalgBasePromotionPattern`.
///
/// Both constructors forward to the name-based constructor of the base class;
/// they differ only in where the operation name comes from.
template <typename OpTy>
struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
  /// SFINAE: this constructor only participates in overload resolution for
  /// concrete ops that have a static `getOperationName` method. The second
  /// (defaulted) template parameter is what makes the constraint effective:
  /// when `ConcreteOpTy::getOperationName()` is ill-formed, substitution fails
  /// and the constructor is dropped from the overload set instead of producing
  /// a hard error inside the initializer list.
  template <typename ConcreteOpTy = OpTy,
            typename = decltype(ConcreteOpTy::getOperationName())>
  LinalgPromotionPattern(
      MLIRContext *context, LinalgPromotionOptions options,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBasePromotionPattern(ConcreteOpTy::getOperationName(), context,
                                   options, filter, benefit) {}

  /// This constructor is available to anyone: the operation name is supplied
  /// explicitly through `opName`, so `OpTy` needs no `getOperationName`.
  LinalgPromotionPattern(
      StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBasePromotionPattern(opName, context, options, filter, benefit) {}
};
///
/// Linalg vectorization patterns.
///
/// Apply the `vectorizeLinalgOp` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.

/// Options for the vectorization patterns.
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};
/// Base class of the Linalg vectorization patterns. It stores the `filter`
/// and declares the rewrite entry point; the implementation is out of line.
struct LinalgBaseVectorizationPattern : public RewritePattern {
/// MatchAnyOpTag-based constructor with a mandatory `filter`.
/// `benefit` is optional and defaults to 1.
LinalgBaseVectorizationPattern(MLIRContext *context,
LinalgTransformationFilter filter,
PatternBenefit benefit = 1);
/// Name-based constructor with an optional `filter`.
/// `opName` identifies the operation to match; `benefit` defaults to 1.
LinalgBaseVectorizationPattern(
StringRef opName, MLIRContext *context,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};
/// Vectorization pattern whose constructors additionally accept a
/// `LinalgVectorizationOptions` argument. The options value is accepted but
/// not forwarded: both constructors delegate to the matching
/// `LinalgBaseVectorizationPattern` constructor with the remaining arguments.
struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
  /// Both constructors below are available to anyone.

  /// MatchAnyOpTag-based constructor; `filter` is mandatory, `options` and
  /// `benefit` are optional.
  LinalgVectorizationPattern(
      MLIRContext *context, LinalgTransformationFilter filter,
      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
      PatternBenefit benefit = 1)
      : LinalgBaseVectorizationPattern(context, filter, benefit) {}

  /// Name-based constructor; `opName` selects the operation to match and
  /// `options`, `filter` and `benefit` are all optional.
  LinalgVectorizationPattern(
      StringRef opName, MLIRContext *context,
      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
};
//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Switches controlling which enabling transformations are applied.
/// Hoisting transformations are always deemed beneficial, so every switch
/// starts out enabled and has to be turned off explicitly.
struct LinalgEnablingOptions {
  /// Whether LICM is enabled (on by default).
  bool licm = true;
  /// Whether hoisting of redundant vector transfer ops is enabled (on by
  /// default).
  bool hoistRedundantVectorTransfers = true;
  /// Whether hoisting of redundant vector transfer ops on tensor is enabled
  /// (on by default).
  bool hoistRedundantVectorTransfersOnTensor = true;

  /// Fluent setters: each one stores `val` into the corresponding switch and
  /// returns `*this` so that calls can be chained.

  /// Set `licm`.
  LinalgEnablingOptions &enableLICM(bool val = true) {
    licm = val;
    return *this;
  }

  /// Set `hoistRedundantVectorTransfers`.
  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
    hoistRedundantVectorTransfers = val;
    return *this;
  }

  /// Set `hoistRedundantVectorTransfersOnTensor`.
  LinalgEnablingOptions &
  enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
    hoistRedundantVectorTransfersOnTensor = val;
    return *this;
  }
};
/// Options selecting how vector ops are lowered down to 1-D and scf.for form.
///
/// The data members come first, in their original declaration order, followed
/// by one fluent setter per member. Every setter stores its argument and
/// returns `*this` so that calls can be chained.
struct LinalgVectorLoweringOptions {
  /// Lowering of vector.contract. In a progressive lowering of vectors, this
  /// would be the 1st step. Off by default.
  bool contractionLowering = false;
  /// Lowering of vector.multi_reduce. In a progressive lowering of vectors,
  /// this would be the 2nd step. Off by default.
  bool multiReductionLowering = false;
  /// Full / partial vector.transfer splits. In a progressive lowering of
  /// vectors, this would be the 3rd step. Off by default.
  bool transferPartialRewrite = false;
  /// Lowering of vector.transfer to scf. In a progressive lowering of vectors,
  /// this would be the 4th step. Off by default.
  bool transferToSCFConversion = false;
  /// Maximal transfer rank under which we do not lower further. Defaults to 1.
  int64_t maxTransferRank = 1;
  /// Vector transfer lowering. In a progressive lowering of vectors, this
  /// would be the 5th step. Vector lowering operations may result in
  /// surprising behavior when composing multiple codegen strategies.
  /// NOTE(review): this flag defaults to `true`, i.e. it is opt-out rather
  /// than opt-in; confirm that this is intended given the caveat above.
  bool transferLowering = true;
  /// Lowering of vector.shape_cast to insert/extract. In a progressive
  /// lowering of vectors, this would be the 6th step. On by default.
  bool shapeCastLowering = true;
  /// Lowering of vector.transpose. In a progressive lowering of vectors, this
  /// would be the 7th step. Off by default.
  bool transposeLowering = false;
  /// AVX2-specific lowerings. Off by default.
  bool avx2Lowering = false;
  /// Configuration of the post staged-patterns late vector.transfer to scf
  /// conversion.
  VectorTransferToSCFOptions vectorTransferToSCFOptions;
  /// Configuration of late vector transformations.
  vector::VectorTransformsOptions vectorTransformOptions;
  /// Configuration of specialized vector lowerings.
  x86vector::avx2::LoweringOptions avx2LoweringOptions;

  /// Set `contractionLowering`.
  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
    contractionLowering = val;
    return *this;
  }

  /// Set `multiReductionLowering`.
  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
    multiReductionLowering = val;
    return *this;
  }

  /// Set `transferPartialRewrite`.
  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
    transferPartialRewrite = val;
    return *this;
  }

  /// Set `transferToSCFConversion`.
  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
    transferToSCFConversion = val;
    return *this;
  }

  /// Set `maxTransferRank`.
  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
    maxTransferRank = val;
    return *this;
  }

  /// Set `transferLowering`.
  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
    transferLowering = val;
    return *this;
  }

  /// Set `shapeCastLowering`.
  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
    shapeCastLowering = val;
    return *this;
  }

  /// Set `transposeLowering`.
  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
    transposeLowering = val;
    return *this;
  }

  /// Set `avx2Lowering`.
  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
    avx2Lowering = val;
    return *this;
  }

  /// Replace `vectorTransferToSCFOptions` with `options`.
  LinalgVectorLoweringOptions &
  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
    vectorTransferToSCFOptions = options;
    return *this;
  }

  /// Replace `vectorTransformOptions` with `options`.
  LinalgVectorLoweringOptions &
  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
    vectorTransformOptions = options;
    return *this;
  }

  /// Replace `avx2LoweringOptions` with `options`.
  LinalgVectorLoweringOptions &
  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
    avx2LoweringOptions = options;
    return *this;
  }
};
//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
/// Trait to check if T provides a `getOperationName` method.
/// This alias is only well-formed when the expression `T::getOperationName()`
/// is valid; the `Args` pack is not referenced by the alias.
template <typename T, typename... Args>
using has_get_operation_name = decltype(T::getOperationName());
/// Detection wrapper: applies `llvm::is_detected` to `has_get_operation_name`
/// so that the check can be queried through `::value` (as done by the SFINAE
/// helpers that use it).
template <typename T>
using detect_has_get_operation_name =
llvm::is_detected<has_get_operation_name, T>;
/// SFINAE-selected overload for a single C++ op class that exposes a static
/// `getOperationName` method: registers a name-based
/// `LinalgVectorizationPattern` for that op in `patternList`, passing along
/// `options` and the filter `f`. The trailing `typename = void` parameter
/// keeps this template's signature distinct from the overload for classes
/// without `getOperationName`.
template <
    typename OpType,
    typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
    typename = void>
void insertVectorizationPatternImpl(RewritePatternSet &patternList,
                                    linalg::LinalgVectorizationOptions options,
                                    linalg::LinalgTransformationFilter f) {
  auto *ctx = patternList.getContext();
  patternList.add<linalg::LinalgVectorizationPattern>(
      OpType::getOperationName(), ctx, options, f);
}
/// SFINAE-selected overload for a single C++ class without a
/// `getOperationName` method (e.g. an OpInterface): registers a
/// MatchAnyOpTag-based `LinalgVectorizationPattern` in `patternList`, using
/// the filter `f` after `addOpFilter<OpType>()` has been applied to it.
template <typename OpType, typename = std::enable_if_t<
                               !detect_has_get_operation_name<OpType>::value>>
void insertVectorizationPatternImpl(RewritePatternSet &patternList,
                                    linalg::LinalgVectorizationOptions options,
                                    linalg::LinalgTransformationFilter f) {
  auto *ctx = patternList.getContext();
  patternList.add<linalg::LinalgVectorizationPattern>(
      ctx, f.addOpFilter<OpType>(), options);
}
/// Variadic helper that inserts vectorization patterns for each of the C++
/// classes in `OpTypes`, in pack order, by calling
/// `insertVectorizationPatternImpl<OpType>(patternList, options, f)` once per
/// type. `f` defaults to a default-constructed filter.
template <typename... OpTypes>
void insertVectorizationPatterns(RewritePatternSet &patternList,
                                 linalg::LinalgVectorizationOptions options,
                                 linalg::LinalgTransformationFilter f =
                                     linalg::LinalgTransformationFilter()) {
  // Pack expansion inside a braced initializer evaluates the calls left to
  // right; the leading 0 keeps the array non-empty for an empty pack.
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
  int expansion[] = {
      0,
      (insertVectorizationPatternImpl<OpTypes>(patternList, options, f), 0)...};
  (void)expansion;
}
///
/// Linalg lowering patterns.
///
/// Apply the `linalgLowerOpToLoops` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `linalgLowerOpToLoops` for more details.

/// Selects the lowering target of `LinalgLoweringPattern`.
enum class LinalgLoweringType {
/// Lower to library calls.
LibraryCall = 0,
/// Lower to loops (via `linalgOpToLoops`).
Loops = 1,
/// Lower to affine loops (via `linalgOpToAffineLoops`).
AffineLoops = 2,
/// Lower to parallel loops (via `linalgOpToParallelLoops`).
ParallelLoops = 3
};
/// Pattern lowering the Linalg op `OpTy` according to a `LinalgLoweringType`.
/// The pattern is anchored on `OpTy::getOperationName()`.
template <typename OpTy>
struct LinalgLoweringPattern : public RewritePattern {
  /// `loweringType` picks the lowering target; `filter` and `benefit` are
  /// optional (default-constructed filter, benefit 1).
  LinalgLoweringPattern(
      MLIRContext *context, LinalgLoweringType loweringType,
      LinalgTransformationFilter filter = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : RewritePattern(OpTy::getOperationName(), benefit, context),
        filter(filter), loweringType(loweringType) {}

  // TODO: Move implementation to .cpp once named ops are auto-generated.
  /// Fails when `op` is not a LinalgOp, when the filter rejects it, when the
  /// requested lowering is `LibraryCall` (not handled here), or when the
  /// selected lowering function reports failure. Otherwise `op` is erased and
  /// success is returned.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    // Short-circuit: the filter is only consulted for actual LinalgOps.
    if (!linalgOp || failed(filter.checkAndNotify(rewriter, linalgOp)))
      return failure();

    bool loweringFailed = false;
    switch (loweringType) {
    case LinalgLoweringType::LibraryCall:
      // TODO: Move lowering to library calls here.
      return failure();
    case LinalgLoweringType::Loops:
      loweringFailed = failed(linalgOpToLoops(rewriter, op));
      break;
    case LinalgLoweringType::AffineLoops:
      loweringFailed = failed(linalgOpToAffineLoops(rewriter, op));
      break;
    case LinalgLoweringType::ParallelLoops:
      loweringFailed = failed(linalgOpToParallelLoops(rewriter, op));
      break;
    }
    if (loweringFailed)
      return failure();

    // The lowering succeeded, so the original op is no longer needed.
    rewriter.eraseOp(op);
    return success();
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Controls whether the pattern lowers to library calls, scf.for, affine.for
  /// or scf.parallel.
  LinalgLoweringType loweringType;
};
/// Linalg generalization patterns.
///
/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
/// `filter` is optional and defaults to a default-constructed
/// `LinalgTransformationFilter`.
void populateLinalgNamedOpsGeneralizationPatterns(
RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
/// Linalg decompose convolutions patterns.
///
/// Populates `patterns` with patterns to decompose high-D convolution ops into
/// low-D ones. This is a step in progressive lowering for convolution ops;
/// afterwards we can vectorize the low-D convolution ops.
/// `filter` and `benefit` are optional (default-constructed filter,
/// benefit 1).
void populateDecomposeConvolutionPatterns(
RewritePatternSet &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
PatternBenefit benefit = 1);
/// Linalg distribution patterns.
///
/// Populates `patterns` with patterns to distribute linalg.tiled_loop.
/// `opts` provides the loop distribution options and `marker` is the
/// transformation filter; both are mandatory.
void populateLinalgDistributeTiledLoopPattern(
RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
const LinalgTransformationFilter &marker);
//===----------------------------------------------------------------------===//
// Op-specific patterns.
//===----------------------------------------------------------------------===//
/// PadTensorOp is not canonicalized away yet, so we provide a transformation to
/// `linalg.generic`.
struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
/// Reuse the constructors of `OpRewritePattern<PadTensorOp>`.
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
/// Pattern entry point; the implementation is defined out of line and is not
/// visible in this header.
LogicalResult matchAndRewrite(PadTensorOp padOp,
PatternRewriter &rewriter) const override;
};
/// Pad the operands of `opToPad` to a static bounding box.
///
/// * `paddingFunc` sets the padding value of the introduced PadTensorOps.
/// * `nofoldFunc` sets the nofold attribute of the introduced PadTensorOps.
/// * `paddedOp` is an output parameter: it is updated to the cloned statically
///   shaped operation.
///
/// Returns the extracted dynamically shaped results. If padding fails, returns
/// failure.
FailureOr<SmallVector<Value>>
rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
const PaddingValueComputationFunction &paddingFunc,
const PaddingNoFoldComputationFunction &nofoldFunc,
LinalgOp &paddedOp);
/// Callback type taking a `PatternRewriter &`, a `PadTensorOp` and a `Value`
/// and returning a `LogicalResult`. It is the type of the optional hook
/// accepted by `GeneralizePadTensorOpPattern`.
using OptimizeCopyFn =
std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
/// Rewrites a PadTensorOp into a sequence of InitTensorOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// An `OptimizeCopyFn` can be supplied to customize the optimization of the
/// copying step.
struct GeneralizePadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
  /// `optimizeCopyFn` is optional and defaults to `nullptr` (an empty
  /// callback); `benefit` defaults to 1.
  GeneralizePadTensorOpPattern(MLIRContext *context,
                               OptimizeCopyFn optimizeCopyFn = nullptr,
                               PatternBenefit benefit = 1)
      : OpRewritePattern<PadTensorOp>(context, benefit),
        optimizeCopyFn(optimizeCopyFn) {}

  /// Pattern entry point; the implementation is defined out of line and is not
  /// visible in this header.
  LogicalResult matchAndRewrite(PadTensorOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  /// Helper defined out of line. NOTE(review): judging by its name it creates
  /// the fill (or generate) op that writes the padding into `dest`; confirm
  /// against the implementation.
  Value createFillOrGenerateOp(PatternRewriter &rewriter, PadTensorOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;

  /// Callback stored at construction time; may be empty.
  OptimizeCopyFn optimizeCopyFn;
};
/// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
///
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function.
///
/// `baseBenefit` (default 1) can be used to offset the benefit of all
/// PadTensorOp vectorization patterns by a certain value.
void populatePadTensorOpVectorizationPatterns(RewritePatternSet &patterns,
PatternBenefit baseBenefit = 1);
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView ...
/// [optional] linalg.fill(%allocOrView, %cst) ...
/// ...
/// linalg.copy(%in, %subView) ...
/// vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView ...
/// ...
/// vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between linalg.copy and transfer_read as
/// well as no interleaved use between linalg.fill and linalg.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
: public OpRewritePattern<vector::TransferReadOp> {
/// Reuse the constructors of `OpRewritePattern<vector::TransferReadOp>`.
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
/// Pattern entry point, anchored on the vector.transfer_read; the
/// implementation is defined out of line.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
PatternRewriter &rewriter) const override;
};
/// Match and rewrite for the pattern:
/// ```
/// %alloc = ...
/// [optional] %view = memref.view %alloc ...
/// %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %allocOrView[...]
/// linalg.copy(%subView, %out)
/// ```
/// into
/// ```
/// [unchanged] %alloc = ...
/// [unchanged] [optional] %view = memref.view %alloc ...
/// [unchanged] %subView = subview %allocOrView...
/// ...
/// vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and linalg.copy.
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
: public OpRewritePattern<vector::TransferWriteOp> {
/// Reuse the constructors of `OpRewritePattern<vector::TransferWriteOp>`.
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
/// Pattern entry point, anchored on the vector.transfer_write; the
/// implementation is defined out of line.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override;
};
/// Converts a convolution op into a vector contraction.
///
/// The conversion expects every ConvOp dimension marked `false` in the *mask*
/// to have size 1. This guarantees that the ConvOp can be lowered to a vector
/// contraction over the dimensions marked `true` in the *mask*.
///
/// A good example for vectorization is ConvNHWCOp, a 2D conv op with channels
/// as the last dimension. Let's vectorize the last 3 dimensions. The initial
/// op definition looks like this:
/// ```
/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 :
/// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>)
/// ```
/// This op can be expressed as a dot product between %arg0 (input) and
/// %arg1 (kernel) which is written into the first entry of %arg2 (output).
/// This is the ConvOp this pass expects and converts into:
/// ```
/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
/// #map1 = affine_map<(d0, d1, d2) -> ()>
/// .....
/// %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32
/// : memref<1x3x3x3xf32>, vector<3x3x3xf32>
/// %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32
/// : memref<1x3x3x3xf32>, vector<3x3x3xf32>
/// %2 = vector.contract {indexing_maps = [#map0, #map0, #map1],
/// iterator_types = ["reduction", "reduction", "reduction"]} %0, %1,
/// %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32
/// store %2, %arg2[%c0, %c0, %c0, %c0] : memref<?x?x?x?xf32>
/// ```
/// The first 2 operations read the input and kernel memory buffers into
/// vectors. Subsequently, they are contracted together and the result is
/// written to the first entry of the output buffer.
///
/// `N` is the expected number of mask entries (the rank checked by the
/// constructor's assertion).
template <typename ConvOp, int N>
class ConvOpVectorization : public OpRewritePattern<ConvOp> {
  using OpRewritePattern<ConvOp>::OpRewritePattern;

  /// One flag per dimension; see the class comment for its meaning.
  SmallVector<bool, 4> mask;

public:
  /// Stores `msk` as the mask and asserts that it has exactly `N` entries.
  ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk)
      : OpRewritePattern<ConvOp>(context), mask(msk) {
    assert(msk.size() == N && "Mask size does not match rank");
  }

  /// Pattern entry point; the implementation is defined out of line and is not
  /// visible in this header.
  LogicalResult matchAndRewrite(ConvOp minOp,
                                PatternRewriter &rewriter) const override;
};
/// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly
/// into a TiledLoopOp where the step divides the iteration space evenly,
/// followed by another TiledLoopOp for the last (partial) iteration (if any).
/// This transformation is called "loop peeling".
///
/// This function peels the `idx`-th loop of the TiledLoopOp. To peel all loops
/// in the loop nest, this function must be called multiple times.
///
/// After loop peeling, this function tries to simplify/canonicalize affine.min
/// and affine.max ops in the body of the two TiledLoopOps. For more details,
/// refer to `mlir::scf::peelAndCanonicalizeForLoop`.
///
/// The return value indicates whether the loop was rewritten or not. Loops are
/// not rewritten if:
/// * Loop step size is 1 or
/// * Loop bounds and step size are static, and step already divides the
/// iteration space evenly.
///
/// Note: This function rewrites the given TiledLoopOp in-place and clones the
/// TiledLoopOp operation for the last iteration. It replaces all uses of the
/// unpeeled TiledLoopOp with the results of the newly generated TiledLoopOp.
///
/// `result` is an output parameter. NOTE(review): presumably it receives the
/// TiledLoopOp generated for the last (partial) iteration; confirm against
/// the implementation.
LogicalResult peelAndCanonicalizeTiledLoop(RewriterBase &rewriter,
TiledLoopOp loopOp, int64_t idx,
TiledLoopOp &result);
//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
/// 1. the first stage consists of a list of FrozenRewritePatternSet
/// (`stage1Patterns`). Each FrozenRewritePatternSet in this list is applied
/// once, in order.
/// 2. the second stage consists of a single FrozenRewritePatternSet
/// (`stage2Patterns`) that is applied greedily until convergence.
/// 3. the third stage consists of applying a lambda (`stage3Lambda`),
/// generally used for non-local transformation effects. This allows creating
/// custom fused transformations where patterns can be ordered and applied at
/// a finer granularity than a sequence of traditional compiler passes.
///
/// `stage3Lambda` is optional and defaults to `nullptr`.
LogicalResult applyStagedPatterns(
Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
const FrozenRewritePatternSet &stage2Patterns,
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
/// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
: public OpRewritePattern<tensor::ExtractSliceOp> {
/// Reuse the constructors of `OpRewritePattern<tensor::ExtractSliceOp>`.
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
/// Pattern entry point, anchored on the tensor.extract_slice; the
/// implementation is defined out of line.
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override;
};
} // namespace linalg
} // namespace mlir
#endif // DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_