| //===- RuntimeOpVerification.cpp - Op Verification ------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" |
| |
| using namespace mlir; |
| |
| namespace mlir { |
| namespace memref { |
| namespace { |
| struct CastOpInterface |
| : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface, |
| CastOp> { |
| void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
| Location loc) const { |
| auto castOp = cast<CastOp>(op); |
| auto srcType = cast<BaseMemRefType>(castOp.getSource().getType()); |
| |
| // Nothing to check if the result is an unranked memref. |
| auto resultType = dyn_cast<MemRefType>(castOp.getType()); |
| if (!resultType) |
| return; |
| |
| if (isa<UnrankedMemRefType>(srcType)) { |
| // Check rank. |
| Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); |
| Value resultRank = |
| builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); |
| Value isSameRank = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, srcRank, resultRank); |
| builder.create<cf::AssertOp>( |
| loc, isSameRank, |
| RuntimeVerifiableOpInterface::generateErrorMessage(op, |
| "rank mismatch")); |
| } |
| |
| // Get source offset and strides. We do not have an op to get offsets and |
| // strides from unranked memrefs, so cast the source to a type with fully |
| // dynamic layout, from which we can then extract the offset and strides. |
| // (Rank was already verified.) |
| int64_t dynamicOffset = ShapedType::kDynamic; |
| SmallVector<int64_t> dynamicShape(resultType.getRank(), |
| ShapedType::kDynamic); |
| auto stridedLayout = StridedLayoutAttr::get(builder.getContext(), |
| dynamicOffset, dynamicShape); |
| auto dynStridesType = |
| MemRefType::get(dynamicShape, resultType.getElementType(), |
| stridedLayout, resultType.getMemorySpace()); |
| Value helperCast = |
| builder.create<CastOp>(loc, dynStridesType, castOp.getSource()); |
| auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast); |
| |
| // Check dimension sizes. |
| for (const auto &it : llvm::enumerate(resultType.getShape())) { |
| // Static dim size -> static/dynamic dim size does not need verification. |
| if (auto rankedSrcType = dyn_cast<MemRefType>(srcType)) |
| if (!rankedSrcType.isDynamicDim(it.index())) |
| continue; |
| |
| // Static/dynamic dim size -> dynamic dim size does not need verification. |
| if (resultType.isDynamicDim(it.index())) |
| continue; |
| |
| Value srcDimSz = |
| builder.create<DimOp>(loc, castOp.getSource(), it.index()); |
| Value resultDimSz = |
| builder.create<arith::ConstantIndexOp>(loc, it.value()); |
| Value isSameSz = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); |
| builder.create<cf::AssertOp>( |
| loc, isSameSz, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "size mismatch of dim " + std::to_string(it.index()))); |
| } |
| |
| // Get result offset and strides. |
| int64_t resultOffset; |
| SmallVector<int64_t> resultStrides; |
| if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) |
| return; |
| |
| // Check offset. |
| if (resultOffset != ShapedType::kDynamic) { |
| // Static/dynamic offset -> dynamic offset does not need verification. |
| Value srcOffset = metadataOp.getResult(1); |
| Value resultOffsetVal = |
| builder.create<arith::ConstantIndexOp>(loc, resultOffset); |
| Value isSameOffset = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); |
| builder.create<cf::AssertOp>( |
| loc, isSameOffset, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "offset mismatch")); |
| } |
| |
| // Check strides. |
| for (const auto &it : llvm::enumerate(resultStrides)) { |
| // Static/dynamic stride -> dynamic stride does not need verification. |
| if (it.value() == ShapedType::kDynamic) |
| continue; |
| |
| Value srcStride = |
| metadataOp.getResult(2 + resultType.getRank() + it.index()); |
| Value resultStrideVal = |
| builder.create<arith::ConstantIndexOp>(loc, it.value()); |
| Value isSameStride = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); |
| builder.create<cf::AssertOp>( |
| loc, isSameStride, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "stride mismatch of dim " + std::to_string(it.index()))); |
| } |
| } |
| }; |
| |
| /// Verifies that the indices on load/store ops are in-bounds of the memref's |
| /// index space: 0 <= index#i < dim#i |
| template <typename LoadStoreOp> |
| struct LoadStoreOpInterface |
| : public RuntimeVerifiableOpInterface::ExternalModel< |
| LoadStoreOpInterface<LoadStoreOp>, LoadStoreOp> { |
| void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
| Location loc) const { |
| auto loadStoreOp = cast<LoadStoreOp>(op); |
| |
| auto memref = loadStoreOp.getMemref(); |
| auto rank = memref.getType().getRank(); |
| if (rank == 0) { |
| return; |
| } |
| auto indices = loadStoreOp.getIndices(); |
| |
| auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value assertCond; |
| for (auto i : llvm::seq<int64_t>(0, rank)) { |
| auto index = indices[i]; |
| |
| auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); |
| |
| auto geLow = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sge, index, zero); |
| auto ltHigh = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::slt, index, dimOp); |
| auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh); |
| |
| assertCond = |
| i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp) |
| : andOp; |
| } |
| builder.create<cf::AssertOp>( |
| loc, assertCond, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "out-of-bounds access")); |
| } |
| }; |
| |
| /// Compute the linear index for the provided strided layout and indices. |
| Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset, |
| ArrayRef<OpFoldResult> strides, |
| ArrayRef<OpFoldResult> indices) { |
| auto [expr, values] = computeLinearIndex(offset, strides, indices); |
| auto index = |
| affine::makeComposedFoldedAffineApply(builder, loc, expr, values); |
| return getValueOrCreateConstantIndexOp(builder, loc, index); |
| } |
| |
| /// Returns two Values representing the bounds of the provided strided layout |
| /// metadata. The bounds are returned as a half open interval -- [low, high). |
| std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, |
| OpFoldResult offset, |
| ArrayRef<OpFoldResult> strides, |
| ArrayRef<OpFoldResult> sizes) { |
| auto zeros = SmallVector<int64_t>(sizes.size(), 0); |
| auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros); |
| auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices); |
| auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes); |
| return {lowerBound, upperBound}; |
| } |
| |
| /// Returns two Values representing the bounds of the memref. The bounds are |
| /// returned as a half open interval -- [low, high). |
| std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc, |
| TypedValue<BaseMemRefType> memref) { |
| auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref); |
| auto offset = runtimeMetadata.getConstifiedMixedOffset(); |
| auto strides = runtimeMetadata.getConstifiedMixedStrides(); |
| auto sizes = runtimeMetadata.getConstifiedMixedSizes(); |
| return computeLinearBounds(builder, loc, offset, strides, sizes); |
| } |
| |
| /// Verifies that the linear bounds of a reinterpret_cast op are within the |
| /// linear bounds of the base memref: low >= baseLow && high <= baseHigh |
| struct ReinterpretCastOpInterface |
| : public RuntimeVerifiableOpInterface::ExternalModel< |
| ReinterpretCastOpInterface, ReinterpretCastOp> { |
| void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
| Location loc) const { |
| auto reinterpretCast = cast<ReinterpretCastOp>(op); |
| auto baseMemref = reinterpretCast.getSource(); |
| auto resultMemref = |
| cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult()); |
| |
| builder.setInsertionPointAfter(op); |
| |
| // Compute the linear bounds of the base memref |
| auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); |
| |
| // Compute the linear bounds of the resulting memref |
| auto [low, high] = computeLinearBounds(builder, loc, resultMemref); |
| |
| // Check low >= baseLow |
| auto geLow = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sge, low, baseLow); |
| |
| // Check high <= baseHigh |
| auto leHigh = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sle, high, baseHigh); |
| |
| auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); |
| |
| builder.create<cf::AssertOp>( |
| loc, assertCond, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, |
| "result of reinterpret_cast is out-of-bounds of the base memref")); |
| } |
| }; |
| |
| /// Verifies that the linear bounds of a subview op are within the linear bounds |
| /// of the base memref: low >= baseLow && high <= baseHigh |
| /// TODO: This is not yet a full runtime verification of subview. For example, |
| /// consider: |
| /// %m = memref.alloc(%c10, %c10) : memref<10x10xf32> |
| /// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1] |
| /// : memref<?x?xf32> to memref<?x?xf32> |
| /// The subview is in-bounds of the entire base memref but the first dimension |
| /// is out-of-bounds. Future work would verify the bounds on a per-dimension |
| /// basis. |
| struct SubViewOpInterface |
| : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface, |
| SubViewOp> { |
| void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
| Location loc) const { |
| auto subView = cast<SubViewOp>(op); |
| auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource()); |
| auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult()); |
| |
| builder.setInsertionPointAfter(op); |
| |
| // Compute the linear bounds of the base memref |
| auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref); |
| |
| // Compute the linear bounds of the resulting memref |
| auto [low, high] = computeLinearBounds(builder, loc, resultMemref); |
| |
| // Check low >= baseLow |
| auto geLow = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sge, low, baseLow); |
| |
| // Check high <= baseHigh |
| auto leHigh = builder.createOrFold<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::sle, high, baseHigh); |
| |
| auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh); |
| |
| builder.create<cf::AssertOp>( |
| loc, assertCond, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "subview is out-of-bounds of the base memref")); |
| } |
| }; |
| |
| struct ExpandShapeOpInterface |
| : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface, |
| ExpandShapeOp> { |
| void generateRuntimeVerification(Operation *op, OpBuilder &builder, |
| Location loc) const { |
| auto expandShapeOp = cast<ExpandShapeOp>(op); |
| |
| // Verify that the expanded dim sizes are a product of the collapsed dim |
| // size. |
| for (const auto &it : |
| llvm::enumerate(expandShapeOp.getReassociationIndices())) { |
| Value srcDimSz = |
| builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index()); |
| int64_t groupSz = 1; |
| bool foundDynamicDim = false; |
| for (int64_t resultDim : it.value()) { |
| if (expandShapeOp.getResultType().isDynamicDim(resultDim)) { |
| // Keep this assert here in case the op is extended in the future. |
| assert(!foundDynamicDim && |
| "more than one dynamic dim found in reassoc group"); |
| (void)foundDynamicDim; |
| foundDynamicDim = true; |
| continue; |
| } |
| groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); |
| } |
| Value staticResultDimSz = |
| builder.create<arith::ConstantIndexOp>(loc, groupSz); |
| // staticResultDimSz must divide srcDimSz evenly. |
| Value mod = |
| builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz); |
| Value isModZero = builder.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, mod, |
| builder.create<arith::ConstantIndexOp>(loc, 0)); |
| builder.create<cf::AssertOp>( |
| loc, isModZero, |
| RuntimeVerifiableOpInterface::generateErrorMessage( |
| op, "static result dims in reassoc group do not " |
| "divide src dim evenly")); |
| } |
| } |
| }; |
| } // namespace |
| } // namespace memref |
| } // namespace mlir |
| |
| void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { |
| CastOp::attachInterface<CastOpInterface>(*ctx); |
| ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); |
| LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx); |
| ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx); |
| StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx); |
| SubViewOp::attachInterface<SubViewOpInterface>(*ctx); |
| |
| // Load additional dialects of which ops may get created. |
| ctx->loadDialect<affine::AffineDialect, arith::ArithDialect, |
| cf::ControlFlowDialect>(); |
| }); |
| } |