| //===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| #include "mlir/IR/PatternMatch.h" |
| |
| using namespace mlir; |
| using namespace mlir::tensor; |
| |
| namespace { |
| |
| /// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts. |
| /// |
| /// %concat = tensor.concat dim(1) %0, %1 : |
| /// (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32> |
| /// |
| /// Becomes |
| /// |
| /// %empty = tensor.empty() : tensor<2x7xf32> |
| /// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1] |
| /// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1] |
| struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> { |
| using OpRewritePattern<ConcatOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ConcatOp concatOp, |
| PatternRewriter &rewriter) const override { |
| FailureOr<SmallVector<Value>> decomposed = |
| concatOp.decomposeOperation(rewriter); |
| if (failed(decomposed)) { |
| return rewriter.notifyMatchFailure( |
| concatOp, "failed to get the decomposed insert slices"); |
| } |
| rewriter.replaceOp(concatOp, decomposed.value()[0]); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::tensor::populateDecomposeTensorConcatPatterns( |
| RewritePatternSet &patterns) { |
| patterns.add<DecomposeTensorConcatOp>(patterns.getContext()); |
| } |