| //===- Utils.cpp - Utilities to support the Tensor dialect ----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements utilities for the Tensor dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Tensor/Utils/Utils.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| |
| using namespace mlir; |
| using namespace mlir::tensor; |
| |
| PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source, |
| Value pad, bool nofold, Location loc, |
| OpBuilder &b) { |
| SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0)); |
| SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0)); |
| for (const auto &en : enumerate(type.getShape())) { |
| // Pad only the static dimensions of the result tensor type. |
| if (ShapedType::isDynamic(en.value())) |
| continue; |
| // Compute the padding width. |
| AffineExpr d0; |
| bindDims(b.getContext(), d0); |
| OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index()); |
| high[en.index()] = |
| affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz}); |
| } |
| return b.create<PadOp>(loc, type, source, low, high, pad, nofold); |
| } |
| |
| SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b, |
| Location loc, |
| Value rankedTensor) { |
| auto tensorTy = cast<RankedTensorType>(rankedTensor.getType()); |
| SmallVector<Value> dynamicDims; |
| for (const auto &en : llvm::enumerate(tensorTy.getShape())) { |
| if (en.value() == ShapedType::kDynamic) |
| dynamicDims.push_back( |
| b.create<tensor::DimOp>(loc, rankedTensor, en.index())); |
| } |
| return dynamicDims; |
| } |
| |
| FailureOr<RankedTensorType> |
| mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType, |
| ArrayRef<int64_t> transposeVector) { |
| if (transposeVector.empty()) |
| return rankedTensorType; |
| |
| if (!isPermutationVector(transposeVector) || |
| transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank())) |
| return failure(); |
| |
| SmallVector<int64_t> transposedShape(rankedTensorType.getShape().begin(), |
| rankedTensorType.getShape().end()); |
| applyPermutationToVector(transposedShape, transposeVector); |
| |
| using RTTBuilder = RankedTensorType::Builder; |
| RankedTensorType transposedTensorType = |
| RTTBuilder(rankedTensorType).setShape(transposedShape); |
| return transposedTensorType; |
| } |
| |
| bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { |
| llvm::SmallBitVector droppedDims = op.getDroppedDims(); |
| int64_t srcDim = 0; |
| // Source dims and destination dims (apart from dropped dims) must have the |
| // same size. |
| for (int64_t resultDim = 0; resultDim < op.getDestType().getRank(); |
| ++resultDim) { |
| if (droppedDims.test(resultDim)) { |
| continue; |
| } |
| FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( |
| op.getSource(), op.getResult(), srcDim, resultDim); |
| if (failed(equalDimSize) || !*equalDimSize) |
| return false; |
| ++srcDim; |
| } |
| |
| return true; |
| } |
| |
| bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { |
| llvm::SmallBitVector droppedDims = op.getDroppedDims(); |
| int64_t resultDim = 0; |
| // Source dims and result dims (apart from dropped dims) must have the same |
| // size. |
| RankedTensorType sourceType = op.getSourceType(); |
| for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) { |
| if (droppedDims.test(dim)) { |
| // ExtractSlice may drop unit dimensions that result from taking a size-1 |
| // slice from a non-size-1 source dimension. |
| if (sourceType.getDimSize(dim) != 1) |
| return false; |
| continue; |
| } |
| FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( |
| op.getSource(), op.getResult(), dim, resultDim); |
| if (failed(equalDimSize) || !*equalDimSize) |
| return false; |
| ++resultDim; |
| } |
| |
| return true; |
| } |