| //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" |
| #include "../PassDetail.h" |
| #include "../SPIRVCommon/Pattern.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "arith-to-spirv-pattern" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Operation Conversion |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Converts composite arith.constant operation to spv.Constant. |
| struct ConstantCompositeOpPattern final |
| : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts scalar arith.constant operation to spv.Constant. |
| struct ConstantScalarOpPattern final |
| : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.remsi to GLSL SPIR-V ops. |
| /// |
| /// This cannot be merged into the template unary/binary pattern due to Vulkan |
| /// restrictions over spv.SRem and spv.SMod. |
| struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> { |
| using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.remsi to OpenCL SPIR-V ops. |
| struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> { |
| using OpConversionPattern<arith::RemSIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts bitwise operations to SPIR-V operations. This is a special pattern |
| /// other than the BinaryOpPatternPattern because if the operands are boolean |
| /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For |
| /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. |
| template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> |
| struct BitwiseOpPattern final : public OpConversionPattern<Op> { |
| using OpConversionPattern<Op>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.xori to SPIR-V operations. |
| struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { |
| using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.xori to SPIR-V operations if the type of source is i1 or |
| /// vector of i1. |
| struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { |
| using OpConversionPattern<arith::XOrIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of |
| /// i1. |
| struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { |
| using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.extui to spv.Select if the type of source is i1 or vector of |
| /// i1. |
| struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { |
| using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of |
| /// i1. |
| struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { |
| using OpConversionPattern<arith::TruncIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts type-casting standard operations to SPIR-V operations. |
| template <typename Op, typename SPIRVOp> |
| struct TypeCastingOpPattern final : public OpConversionPattern<Op> { |
| using OpConversionPattern<Op>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts integer compare operation on i1 type operands to SPIR-V ops. |
| class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { |
| public: |
| using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts integer compare operation to SPIR-V ops. |
| class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { |
| public: |
| using OpConversionPattern<arith::CmpIOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts floating-point comparison operations to SPIR-V ops. |
| class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts floating point NaN check to SPIR-V ops. This pattern requires |
| /// Kernel capability. |
| class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| /// Converts floating point NaN check to SPIR-V ops. This pattern does not |
| /// require additional capability. |
| class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override; |
| }; |
| |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion Helpers |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts the given `srcAttr` into a boolean attribute if it holds an |
| /// integral value. Returns null attribute if conversion fails. |
| static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { |
| if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>()) |
| return boolAttr; |
| if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>()) |
| return builder.getBoolAttr(intAttr.getValue().getBoolValue()); |
| return BoolAttr(); |
| } |
| |
| /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
| /// Returns null attribute if conversion fails. |
| static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, |
| Builder builder) { |
| // If the source number uses less active bits than the target bitwidth, then |
| // it should be safe to convert. |
| if (srcAttr.getValue().isIntN(dstType.getWidth())) |
| return builder.getIntegerAttr(dstType, srcAttr.getInt()); |
| |
| // XXX: Try again by interpreting the source number as a signed value. |
| // Although integers in the standard dialect are signless, they can represent |
| // a signed number. It's the operation decides how to interpret. This is |
| // dangerous, but it seems there is no good way of handling this if we still |
| // want to change the bitwidth. Emit a message at least. |
| if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { |
| auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); |
| LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" |
| << dstAttr << "' for type '" << dstType << "'\n"); |
| return dstAttr; |
| } |
| |
| LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr |
| << "' illegal: cannot fit into target type '" |
| << dstType << "'\n"); |
| return IntegerAttr(); |
| } |
| |
| /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
| /// Returns null attribute if `dstType` is not 32-bit or conversion fails. |
| static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, |
| Builder builder) { |
| // Only support converting to float for now. |
| if (!dstType.isF32()) |
| return FloatAttr(); |
| |
| // Try to convert the source floating-point number to single precision. |
| APFloat dstVal = srcAttr.getValue(); |
| bool losesInfo = false; |
| APFloat::opStatus status = |
| dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); |
| if (status != APFloat::opOK || losesInfo) { |
| LLVM_DEBUG(llvm::dbgs() |
| << srcAttr << " illegal: cannot fit into converted type '" |
| << dstType << "'\n"); |
| return FloatAttr(); |
| } |
| |
| return builder.getF32FloatAttr(dstVal.convertToFloat()); |
| } |
| |
| /// Returns true if the given `type` is a boolean scalar or vector type. |
| static bool isBoolScalarOrVector(Type type) { |
| if (type.isInteger(1)) |
| return true; |
| if (auto vecType = type.dyn_cast<VectorType>()) |
| return vecType.getElementType().isInteger(1); |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp with composite type |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConstantCompositeOpPattern::matchAndRewrite( |
| arith::ConstantOp constOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto srcType = constOp.getType().dyn_cast<ShapedType>(); |
| if (!srcType) |
| return failure(); |
| |
| // arith.constant should only have vector or tenor types. |
| assert((srcType.isa<VectorType, RankedTensorType>())); |
| |
| auto dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return failure(); |
| |
| auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>(); |
| ShapedType dstAttrType = dstElementsAttr.getType(); |
| if (!dstElementsAttr) |
| return failure(); |
| |
| // If the composite type has more than one dimensions, perform linearization. |
| if (srcType.getRank() > 1) { |
| if (srcType.isa<RankedTensorType>()) { |
| dstAttrType = RankedTensorType::get(srcType.getNumElements(), |
| srcType.getElementType()); |
| dstElementsAttr = dstElementsAttr.reshape(dstAttrType); |
| } else { |
| // TODO: add support for large vectors. |
| return failure(); |
| } |
| } |
| |
| Type srcElemType = srcType.getElementType(); |
| Type dstElemType; |
| // Tensor types are converted to SPIR-V array types; vector types are |
| // converted to SPIR-V vector/array types. |
| if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>()) |
| dstElemType = arrayType.getElementType(); |
| else |
| dstElemType = dstType.cast<VectorType>().getElementType(); |
| |
| // If the source and destination element types are different, perform |
| // attribute conversion. |
| if (srcElemType != dstElemType) { |
| SmallVector<Attribute, 8> elements; |
| if (srcElemType.isa<FloatType>()) { |
| for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { |
| FloatAttr dstAttr = |
| convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter); |
| if (!dstAttr) |
| return failure(); |
| elements.push_back(dstAttr); |
| } |
| } else if (srcElemType.isInteger(1)) { |
| return failure(); |
| } else { |
| for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { |
| IntegerAttr dstAttr = convertIntegerAttr( |
| srcAttr, dstElemType.cast<IntegerType>(), rewriter); |
| if (!dstAttr) |
| return failure(); |
| elements.push_back(dstAttr); |
| } |
| } |
| |
| // Unfortunately, we cannot use dialect-specific types for element |
| // attributes; element attributes only works with builtin types. So we need |
| // to prepare another converted builtin types for the destination elements |
| // attribute. |
| if (dstAttrType.isa<RankedTensorType>()) |
| dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); |
| else |
| dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); |
| |
| dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); |
| } |
| |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, |
| dstElementsAttr); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp with scalar type |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ConstantScalarOpPattern::matchAndRewrite( |
| arith::ConstantOp constOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Type srcType = constOp.getType(); |
| if (!srcType.isIntOrIndexOrFloat()) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return failure(); |
| |
| // Floating-point types. |
| if (srcType.isa<FloatType>()) { |
| auto srcAttr = constOp.getValue().cast<FloatAttr>(); |
| auto dstAttr = srcAttr; |
| |
| // Floating-point types not supported in the target environment are all |
| // converted to float type. |
| if (srcType != dstType) { |
| dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter); |
| if (!dstAttr) |
| return failure(); |
| } |
| |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| |
| // Bool type. |
| if (srcType.isInteger(1)) { |
| // arith.constant can use 0/1 instead of true/false for i1 values. We need |
| // to handle that here. |
| auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter); |
| if (!dstAttr) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| |
| // IndexType or IntegerType. Index values are converted to 32-bit integer |
| // values when converting to SPIR-V. |
| auto srcAttr = constOp.getValue().cast<IntegerAttr>(); |
| auto dstAttr = |
| convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter); |
| if (!dstAttr) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RemSIOpGLSLPattern |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns signed remainder for `lhs` and `rhs` and lets the result follow |
| /// the sign of `signOperand`. |
| /// |
| /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment |
| /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative |
| /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod |
| /// if either operand can be negative. Emulate it via spv.UMod. |
| template <typename SignedAbsOp> |
| static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, |
| Value signOperand, OpBuilder &builder) { |
| assert(lhs.getType() == rhs.getType()); |
| assert(lhs == signOperand || rhs == signOperand); |
| |
| Type type = lhs.getType(); |
| |
| // Calculate the remainder with spv.UMod. |
| Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs); |
| Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs); |
| Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs); |
| |
| // Fix the sign. |
| Value isPositive; |
| if (lhs == signOperand) |
| isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs); |
| else |
| isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs); |
| Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs); |
| return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate); |
| } |
| |
| LogicalResult |
| RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>( |
| op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
| adaptor.getOperands()[0], rewriter); |
| rewriter.replaceOp(op, result); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RemSIOpOCLPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Value result = emulateSignedRemainder<spirv::OCLSAbsOp>( |
| op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
| adaptor.getOperands()[0], rewriter); |
| rewriter.replaceOp(op, result); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BitwiseOpPattern |
| //===----------------------------------------------------------------------===// |
| |
| template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> |
| LogicalResult |
| BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite( |
| Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| assert(adaptor.getOperands().size() == 2); |
| auto dstType = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| if (!dstType) |
| return failure(); |
| if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { |
| rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType, |
| adaptor.getOperands()); |
| } else { |
| rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType, |
| adaptor.getOperands()); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XOrIOpLogicalPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult XOrIOpLogicalPattern::matchAndRewrite( |
| arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| assert(adaptor.getOperands().size() == 2); |
| |
| if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
| return failure(); |
| |
| auto dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, |
| adaptor.getOperands()); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // XOrIOpBooleanPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult XOrIOpBooleanPattern::matchAndRewrite( |
| arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| assert(adaptor.getOperands().size() == 2); |
| |
| if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
| return failure(); |
| |
| auto dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType, |
| adaptor.getOperands()); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UIToFPI1Pattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto srcType = adaptor.getOperands().front().getType(); |
| if (!isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| auto dstType = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| Location loc = op.getLoc(); |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.template replaceOpWithNewOp<spirv::SelectOp>( |
| op, dstType, adaptor.getOperands().front(), one, zero); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ExtUII1Pattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto srcType = adaptor.getOperands().front().getType(); |
| if (!isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| auto dstType = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| Location loc = op.getLoc(); |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.template replaceOpWithNewOp<spirv::SelectOp>( |
| op, dstType, adaptor.getOperands().front(), one, zero); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TruncII1Pattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| auto dstType = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| if (!isBoolScalarOrVector(dstType)) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| auto srcType = adaptor.getOperands().front().getType(); |
| // Check if (x & 1) == 1. |
| Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); |
| Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>( |
| loc, srcType, adaptor.getOperands()[0], mask); |
| Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask); |
| |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TypeCastingOpPattern |
| //===----------------------------------------------------------------------===// |
| |
| template <typename Op, typename SPIRVOp> |
| LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite( |
| Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| assert(adaptor.getOperands().size() == 1); |
| auto srcType = adaptor.getOperands().front().getType(); |
| auto dstType = |
| this->getTypeConverter()->convertType(op.getResult().getType()); |
| if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) |
| return failure(); |
| if (dstType == srcType) { |
| // Due to type conversion, we are seeing the same source and target type. |
| // Then we can just erase this operation by forwarding its operand. |
| rewriter.replaceOp(op, adaptor.getOperands().front()); |
| } else { |
| rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType, |
| adaptor.getOperands()); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpIOpBooleanPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CmpIOpBooleanPattern::matchAndRewrite( |
| arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Type operandType = op.getLhs().getType(); |
| if (!isBoolScalarOrVector(operandType)) |
| return failure(); |
| |
| switch (op.getPredicate()) { |
| #define DISPATCH(cmpPredicate, spirvOp) \ |
| case cmpPredicate: \ |
| rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ |
| adaptor.lhs(), adaptor.rhs()); \ |
| return success(); |
| |
| DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); |
| DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); |
| |
| #undef DISPATCH |
| default:; |
| } |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpIOpPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| Type operandType = op.getLhs().getType(); |
| if (isBoolScalarOrVector(operandType)) |
| return failure(); |
| |
| switch (op.getPredicate()) { |
| #define DISPATCH(cmpPredicate, spirvOp) \ |
| case cmpPredicate: \ |
| if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ |
| operandType != this->getTypeConverter()->convertType(operandType)) { \ |
| return op.emitError( \ |
| "bitwidth emulation is not implemented yet on unsigned op"); \ |
| } \ |
| rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ |
| adaptor.lhs(), adaptor.rhs()); \ |
| return success(); |
| |
| DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); |
| DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); |
| DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); |
| DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); |
| DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); |
| DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); |
| DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); |
| |
| #undef DISPATCH |
| } |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpFOpPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult |
| CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| switch (op.getPredicate()) { |
| #define DISPATCH(cmpPredicate, spirvOp) \ |
| case cmpPredicate: \ |
| rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \ |
| adaptor.lhs(), adaptor.rhs()); \ |
| return success(); |
| |
| // Ordered. |
| DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); |
| DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); |
| DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); |
| DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); |
| // Unordered. |
| DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); |
| DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); |
| DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); |
| DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); |
| |
| #undef DISPATCH |
| |
| default: |
| break; |
| } |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpFOpNanKernelPattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( |
| arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| if (op.getPredicate() == arith::CmpFPredicate::ORD) { |
| rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| |
| if (op.getPredicate() == arith::CmpFPredicate::UNO) { |
| rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // CmpFOpNanNonePattern |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult CmpFOpNanNonePattern::matchAndRewrite( |
| arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const { |
| if (op.getPredicate() != arith::CmpFPredicate::ORD && |
| op.getPredicate() != arith::CmpFPredicate::UNO) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| |
| Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); |
| Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); |
| |
| Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan); |
| if (op.getPredicate() == arith::CmpFPredicate::ORD) |
| replace = rewriter.create<spirv::LogicalNotOp>(loc, replace); |
| |
| rewriter.replaceOp(op, replace); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern Population |
| //===----------------------------------------------------------------------===// |
| |
| void mlir::arith::populateArithmeticToSPIRVPatterns( |
| SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| // clang-format off |
| patterns.add< |
| ConstantCompositeOpPattern, |
| ConstantScalarOpPattern, |
| spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>, |
| RemSIOpGLSLPattern, RemSIOpOCLPattern, |
| BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, |
| BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, |
| XOrIOpLogicalPattern, XOrIOpBooleanPattern, |
| spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>, |
| spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>, |
| TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern, |
| TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>, |
| TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, |
| TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern, |
| TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, |
| TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, |
| TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, |
| TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, |
| TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, |
| TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, |
| CmpIOpBooleanPattern, CmpIOpPattern, |
| CmpFOpNanNonePattern, CmpFOpPattern |
| >(typeConverter, patterns.getContext()); |
| // clang-format on |
| |
| // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel |
| // capability is available. |
| patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), |
| /*benefit=*/2); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass Definition |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| struct ConvertArithmeticToSPIRVPass |
| : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> { |
| void runOnFunction() override { |
| auto module = getOperation()->getParentOfType<ModuleOp>(); |
| auto targetAttr = spirv::lookupTargetEnvOrDefault(module); |
| auto target = SPIRVConversionTarget::get(targetAttr); |
| |
| SPIRVTypeConverter::Options options; |
| options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; |
| SPIRVTypeConverter typeConverter(targetAttr, options); |
| |
| RewritePatternSet patterns(&getContext()); |
| mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); |
| |
| if (failed(applyPartialConversion(getOperation(), *target, |
| std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // end anonymous namespace |
| |
| std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() { |
| return std::make_unique<ConvertArithmeticToSPIRVPass>(); |
| } |