| //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" |
| |
| #include "../SPIRVCommon/Pattern.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" |
| #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/DialectResourceBlobManager.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/MathExtras.h" |
| #include <cassert> |
| #include <memory> |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| #define DEBUG_TYPE "arith-to-spirv-pattern" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Conversion Helpers |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts the given `srcAttr` into a boolean attribute if it holds an |
| /// integral value. Returns null attribute if conversion fails. |
| static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { |
| if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr)) |
| return boolAttr; |
| if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr)) |
| return builder.getBoolAttr(intAttr.getValue().getBoolValue()); |
| return {}; |
| } |
| |
| /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
| /// Returns null attribute if conversion fails. |
| static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, |
| Builder builder) { |
| // If the source number uses less active bits than the target bitwidth, then |
| // it should be safe to convert. |
| if (srcAttr.getValue().isIntN(dstType.getWidth())) |
| return builder.getIntegerAttr(dstType, srcAttr.getInt()); |
| |
| // XXX: Try again by interpreting the source number as a signed value. |
| // Although integers in the standard dialect are signless, they can represent |
| // a signed number. It's the operation decides how to interpret. This is |
| // dangerous, but it seems there is no good way of handling this if we still |
| // want to change the bitwidth. Emit a message at least. |
| if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { |
| auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); |
| LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" |
| << dstAttr << "' for type '" << dstType << "'\n"); |
| return dstAttr; |
| } |
| |
| LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr |
| << "' illegal: cannot fit into target type '" |
| << dstType << "'\n"); |
| return {}; |
| } |
| |
| /// Converts the given `srcAttr` to a new attribute of the given `dstType`. |
| /// Returns null attribute if `dstType` is not 32-bit or conversion fails. |
| static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, |
| Builder builder) { |
| // Only support converting to float for now. |
| if (!dstType.isF32()) |
| return FloatAttr(); |
| |
| // Try to convert the source floating-point number to single precision. |
| APFloat dstVal = srcAttr.getValue(); |
| bool losesInfo = false; |
| APFloat::opStatus status = |
| dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); |
| if (status != APFloat::opOK || losesInfo) { |
| LLVM_DEBUG(llvm::dbgs() |
| << srcAttr << " illegal: cannot fit into converted type '" |
| << dstType << "'\n"); |
| return FloatAttr(); |
| } |
| |
| return builder.getF32FloatAttr(dstVal.convertToFloat()); |
| } |
| |
| // Get in IntegerAttr from FloatAttr while preserving the bits. |
| // Useful for converting float constants to integer constants while preserving |
| // the bits. |
| static IntegerAttr |
| getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, |
| ConversionPatternRewriter &rewriter) { |
| APFloat floatVal = floatAttr.getValue(); |
| APInt intVal = floatVal.bitcastToAPInt(); |
| return rewriter.getIntegerAttr(dstType, intVal); |
| } |
| |
| /// Returns true if the given `type` is a boolean scalar or vector type. |
| static bool isBoolScalarOrVector(Type type) { |
| assert(type && "Not a valid type"); |
| if (type.isInteger(1)) |
| return true; |
| |
| if (auto vecType = dyn_cast<VectorType>(type)) |
| return vecType.getElementType().isInteger(1); |
| |
| return false; |
| } |
| |
| /// Creates a scalar/vector integer constant. |
| static Value getScalarOrVectorConstInt(Type type, uint64_t value, |
| OpBuilder &builder, Location loc) { |
| if (auto vectorType = dyn_cast<VectorType>(type)) { |
| Attribute element = IntegerAttr::get(vectorType.getElementType(), value); |
| auto attr = SplatElementsAttr::get(vectorType, element); |
| return spirv::ConstantOp::create(builder, loc, vectorType, attr); |
| } |
| |
| if (auto intType = dyn_cast<IntegerType>(type)) |
| return spirv::ConstantOp::create(builder, loc, type, |
| builder.getIntegerAttr(type, value)); |
| |
| return nullptr; |
| } |
| |
| /// Returns true if scalar/vector type `a` and `b` have the same number of |
| /// bitwidth. |
| static bool hasSameBitwidth(Type a, Type b) { |
| auto getNumBitwidth = [](Type type) { |
| unsigned bw = 0; |
| if (type.isIntOrFloat()) |
| bw = type.getIntOrFloatBitWidth(); |
| else if (auto vecType = dyn_cast<VectorType>(type)) |
| bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); |
| return bw; |
| }; |
| unsigned aBW = getNumBitwidth(a); |
| unsigned bBW = getNumBitwidth(b); |
| return aBW != 0 && bBW != 0 && aBW == bBW; |
| } |
| |
| /// Returns a source type conversion failure for `srcType` and operation `op`. |
| static LogicalResult |
| getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, |
| Type srcType) { |
| return rewriter.notifyMatchFailure( |
| op->getLoc(), |
| llvm::formatv("failed to convert source type '{0}'", srcType)); |
| } |
| |
| /// Returns a source type conversion failure for the result type of `op`. |
| static LogicalResult |
| getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) { |
| assert(op->getNumResults() == 1); |
| return getTypeConversionFailure(rewriter, op, op->getResultTypes().front()); |
| } |
| |
| // TODO: Move to some common place? |
| static std::string getDecorationString(spirv::Decoration decor) { |
| return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor)); |
| } |
| |
| namespace { |
| |
| /// Converts elementwise unary, binary and ternary arith operations to SPIR-V |
| /// operations. Op can potentially support overflow flags. |
| template <typename Op, typename SPIRVOp> |
| struct ElementwiseArithOpPattern final : OpConversionPattern<Op> { |
| using OpConversionPattern<Op>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() <= 3); |
| auto converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
| Type dstType = converter->convertType(op.getType()); |
| if (!dstType) { |
| return rewriter.notifyMatchFailure( |
| op->getLoc(), |
| llvm::formatv("failed to convert type {0} for SPIR-V", op.getType())); |
| } |
| |
| if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && |
| !getElementTypeOrSelf(op.getType()).isIndex() && |
| dstType != op.getType()) { |
| return op.emitError("bitwidth emulation is not implemented yet on " |
| "unsigned op pattern version"); |
| } |
| |
| auto overflowFlags = arith::IntegerOverflowFlags::none; |
| if (auto overflowIface = |
| dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) { |
| if (converter->getTargetEnv().allows( |
| spirv::Extension::SPV_KHR_no_integer_wrap_decoration)) |
| overflowFlags = overflowIface.getOverflowAttr().getValue(); |
| } |
| |
| auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( |
| op, dstType, adaptor.getOperands()); |
| |
| if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw)) |
| newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap), |
| rewriter.getUnitAttr()); |
| |
| if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw)) |
| newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap), |
| rewriter.getUnitAttr()); |
| |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ConstantOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts composite arith.constant operation to spirv.Constant. |
| struct ConstantCompositeOpPattern final |
| : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto srcType = dyn_cast<ShapedType>(constOp.getType()); |
| if (!srcType || srcType.getNumElements() == 1) |
| return failure(); |
| |
| // arith.constant should only have vector or tensor types. This is a MLIR |
| // wide problem at the moment. |
| if (!isa<VectorType, RankedTensorType>(srcType)) |
| return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType"); |
| |
| Type dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return failure(); |
| |
| // Import the resource into the IR to make use of the special handling of |
| // element types later on. |
| mlir::DenseElementsAttr dstElementsAttr; |
| if (auto denseElementsAttr = |
| dyn_cast<DenseElementsAttr>(constOp.getValue())) { |
| dstElementsAttr = denseElementsAttr; |
| } else if (auto resourceAttr = |
| dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) { |
| |
| AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob(); |
| if (!blob) |
| return constOp->emitError("could not find resource blob"); |
| |
| ArrayRef<char> ptr = blob->getData(); |
| |
| // Check that the buffer meets the requirements to get converted to a |
| // DenseElementsAttr |
| bool detectedSplat = false; |
| if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat)) |
| return constOp->emitError("resource is not a valid buffer"); |
| |
| dstElementsAttr = |
| DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr); |
| } else { |
| return constOp->emitError("unsupported elements attribute"); |
| } |
| |
| ShapedType dstAttrType = dstElementsAttr.getType(); |
| |
| // If the composite type has more than one dimensions, perform |
| // linearization. |
| if (srcType.getRank() > 1) { |
| if (isa<RankedTensorType>(srcType)) { |
| dstAttrType = RankedTensorType::get(srcType.getNumElements(), |
| srcType.getElementType()); |
| dstElementsAttr = dstElementsAttr.reshape(dstAttrType); |
| } else { |
| // TODO: add support for large vectors. |
| return failure(); |
| } |
| } |
| |
| Type srcElemType = srcType.getElementType(); |
| Type dstElemType; |
| // Tensor types are converted to SPIR-V array types; vector types are |
| // converted to SPIR-V vector/array types. |
| if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType)) |
| dstElemType = arrayType.getElementType(); |
| else |
| dstElemType = cast<VectorType>(dstType).getElementType(); |
| |
| // If the source and destination element types are different, perform |
| // attribute conversion. |
| if (srcElemType != dstElemType) { |
| SmallVector<Attribute, 8> elements; |
| if (isa<FloatType>(srcElemType)) { |
| for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { |
| Attribute dstAttr = nullptr; |
| // Handle 8-bit float conversion to 8-bit integer. |
| auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); |
| if (typeConverter->getOptions().emulateUnsupportedFloatTypes && |
| srcElemType.getIntOrFloatBitWidth() == 8 && |
| isa<IntegerType>(dstElemType)) { |
| dstAttr = |
| getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); |
| } else { |
| dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), |
| rewriter); |
| } |
| if (!dstAttr) |
| return failure(); |
| elements.push_back(dstAttr); |
| } |
| } else if (srcElemType.isInteger(1)) { |
| return failure(); |
| } else { |
| for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) { |
| IntegerAttr dstAttr = convertIntegerAttr( |
| srcAttr, cast<IntegerType>(dstElemType), rewriter); |
| if (!dstAttr) |
| return failure(); |
| elements.push_back(dstAttr); |
| } |
| } |
| |
| // Unfortunately, we cannot use dialect-specific types for element |
| // attributes; element attributes only works with builtin types. So we |
| // need to prepare another converted builtin types for the destination |
| // elements attribute. |
| if (isa<RankedTensorType>(dstAttrType)) |
| dstAttrType = |
| RankedTensorType::get(dstAttrType.getShape(), dstElemType); |
| else |
| dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); |
| |
| dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); |
| } |
| |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, |
| dstElementsAttr); |
| return success(); |
| } |
| }; |
| |
| /// Converts scalar arith.constant operation to spirv.Constant. |
| struct ConstantScalarOpPattern final |
| : public OpConversionPattern<arith::ConstantOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = constOp.getType(); |
| if (auto shapedType = dyn_cast<ShapedType>(srcType)) { |
| if (shapedType.getNumElements() != 1) |
| return failure(); |
| srcType = shapedType.getElementType(); |
| } |
| if (!srcType.isIntOrIndexOrFloat()) |
| return failure(); |
| |
| Attribute cstAttr = constOp.getValue(); |
| if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr)) |
| cstAttr = elementsAttr.getSplatValue<Attribute>(); |
| |
| Type dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return failure(); |
| |
| // Floating-point types. |
| if (isa<FloatType>(srcType)) { |
| auto srcAttr = cast<FloatAttr>(cstAttr); |
| Attribute dstAttr = srcAttr; |
| |
| // Floating-point types not supported in the target environment are all |
| // converted to float type. |
| auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); |
| if (typeConverter->getOptions().emulateUnsupportedFloatTypes && |
| srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && |
| dstType.getIntOrFloatBitWidth() == 8) { |
| // If the source is an 8-bit float, convert it to a 8-bit integer. |
| dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); |
| if (!dstAttr) |
| return failure(); |
| } else if (srcType != dstType) { |
| dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); |
| if (!dstAttr) |
| return failure(); |
| } |
| |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| |
| // Bool type. |
| if (srcType.isInteger(1)) { |
| // arith.constant can use 0/1 instead of true/false for i1 values. We need |
| // to handle that here. |
| auto dstAttr = convertBoolAttr(cstAttr, rewriter); |
| if (!dstAttr) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| |
| // IndexType or IntegerType. Index values are converted to 32-bit integer |
| // values when converting to SPIR-V. |
| auto srcAttr = cast<IntegerAttr>(cstAttr); |
| IntegerAttr dstAttr = |
| convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter); |
| if (!dstAttr) |
| return failure(); |
| rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // RemSIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns signed remainder for `lhs` and `rhs` and lets the result follow |
| /// the sign of `signOperand`. |
| /// |
| /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment |
| /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative |
| /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod |
| /// if either operand can be negative. Emulate it via spirv.UMod. |
| template <typename SignedAbsOp> |
| static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, |
| Value signOperand, OpBuilder &builder) { |
| assert(lhs.getType() == rhs.getType()); |
| assert(lhs == signOperand || rhs == signOperand); |
| |
| Type type = lhs.getType(); |
| |
| // Calculate the remainder with spirv.UMod. |
| Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); |
| Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); |
| Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); |
| |
| // Fix the sign. |
| Value isPositive; |
| if (lhs == signOperand) |
| isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); |
| else |
| isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); |
| Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); |
| return spirv::SelectOp::create(builder, loc, type, isPositive, abs, |
| absNegate); |
| } |
| |
| /// Converts arith.remsi to GLSL SPIR-V ops. |
| /// |
| /// This cannot be merged into the template unary/binary pattern due to Vulkan |
| /// restrictions over spirv.SRem and spirv.SMod. |
| struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Value result = emulateSignedRemainder<spirv::CLSAbsOp>( |
| op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
| adaptor.getOperands()[0], rewriter); |
| rewriter.replaceOp(op, result); |
| |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.remsi to OpenCL SPIR-V ops. |
| struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Value result = emulateSignedRemainder<spirv::GLSAbsOp>( |
| op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], |
| adaptor.getOperands()[0], rewriter); |
| rewriter.replaceOp(op, result); |
| |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // BitwiseOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts bitwise operations to SPIR-V operations. This is a special pattern |
| /// other than the BinaryOpPatternPattern because if the operands are boolean |
| /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For |
| /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. |
| template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp> |
| struct BitwiseOpPattern final : public OpConversionPattern<Op> { |
| using OpConversionPattern<Op>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 2); |
| Type dstType = this->getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { |
| rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>( |
| op, dstType, adaptor.getOperands()); |
| } else { |
| rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>( |
| op, dstType, adaptor.getOperands()); |
| } |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // XOrIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.xori to SPIR-V operations. |
| struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 2); |
| |
| if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType, |
| adaptor.getOperands()); |
| |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.xori to SPIR-V operations if the type of source is i1 or |
| /// vector of i1. |
| struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 2); |
| |
| if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( |
| op, dstType, adaptor.getOperands()); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // UIToFPOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector |
| /// of i1. |
| struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = adaptor.getOperands().front().getType(); |
| if (!isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| Location loc = op.getLoc(); |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>( |
| op, dstType, adaptor.getOperands().front(), one, zero); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // IndexCastOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.index_cast to spirv.INotEqual if the target type is i1. |
| struct IndexCastIndexI1Pattern final |
| : public OpConversionPattern<arith::IndexCastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!isBoolScalarOrVector(op.getType())) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| Location loc = op.getLoc(); |
| Value zeroIdx = |
| spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx, |
| adaptor.getIn()); |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.index_cast to spirv.Select if the source type is i1. |
| struct IndexCastI1IndexPattern final |
| : public OpConversionPattern<arith::IndexCastOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (!isBoolScalarOrVector(adaptor.getIn().getType())) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| Location loc = op.getLoc(); |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(), |
| one, zero); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ExtSIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector |
| /// of i1. |
| struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Value operand = adaptor.getIn(); |
| if (!isBoolScalarOrVector(operand.getType())) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| Value allOnes; |
| if (auto intTy = dyn_cast<IntegerType>(dstType)) { |
| unsigned componentBitwidth = intTy.getWidth(); |
| allOnes = spirv::ConstantOp::create( |
| rewriter, loc, intTy, |
| rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); |
| } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) { |
| unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); |
| allOnes = spirv::ConstantOp::create( |
| rewriter, loc, vectorTy, |
| SplatElementsAttr::get(vectorTy, |
| APInt::getAllOnes(componentBitwidth))); |
| } else { |
| return rewriter.notifyMatchFailure( |
| loc, llvm::formatv("unhandled type: {0}", dstType)); |
| } |
| |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes, |
| zero); |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor |
| /// vector of i1. |
| struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = adaptor.getIn().getType(); |
| if (isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (dstType == srcType) { |
| // We can have the same source and destination type due to type emulation. |
| // Perform bit shifting to make sure we have the proper leading set bits. |
| |
| unsigned srcBW = |
| getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); |
| unsigned dstBW = |
| getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); |
| assert(srcBW < dstBW); |
| Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW, |
| rewriter, op.getLoc()); |
| |
| // First shift left to sequeeze out all leading bits beyond the original |
| // bitwidth. Here we need to use the original source and result type's |
| // bitwidth. |
| auto shiftLOp = spirv::ShiftLeftLogicalOp::create( |
| rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); |
| |
| // Then we perform arithmetic right shift to make sure we have the right |
| // sign bits for negative values. |
| rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>( |
| op, dstType, shiftLOp, shiftSize); |
| } else { |
| rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, |
| adaptor.getOperands()); |
| } |
| |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // ExtUIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.extui to spirv.Select if the type of source is i1 or vector |
| /// of i1. |
| struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = adaptor.getOperands().front().getType(); |
| if (!isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| Location loc = op.getLoc(); |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>( |
| op, dstType, adaptor.getOperands().front(), one, zero); |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.extui for cases where the type of source is neither i1 nor |
| /// vector of i1. |
| struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = adaptor.getIn().getType(); |
| if (isBoolScalarOrVector(srcType)) |
| return failure(); |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (dstType == srcType) { |
| // We can have the same source and destination type due to type emulation. |
| // Perform bit masking to make sure we don't pollute downstream consumers |
| // with unwanted bits. Here we need to use the original source type's |
| // bitwidth. |
| unsigned bitwidth = |
| getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); |
| Value mask = getScalarOrVectorConstInt( |
| dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter, |
| op.getLoc()); |
| rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, |
| adaptor.getIn(), mask); |
| } else { |
| rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType, |
| adaptor.getOperands()); |
| } |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // TruncIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector |
| /// of i1. |
| struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (!isBoolScalarOrVector(dstType)) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| auto srcType = adaptor.getOperands().front().getType(); |
| // Check if (x & 1) == 1. |
| Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); |
| Value maskedSrc = spirv::BitwiseAndOp::create( |
| rewriter, loc, srcType, adaptor.getOperands()[0], mask); |
| Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); |
| |
| Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero); |
| return success(); |
| } |
| }; |
| |
| /// Converts arith.trunci for cases where the type of result is neither i1 |
| /// nor vector of i1. |
| struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = adaptor.getIn().getType(); |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (isBoolScalarOrVector(dstType)) |
| return failure(); |
| |
| if (dstType == srcType) { |
| // We can have the same source and destination type due to type emulation. |
| // Perform bit masking to make sure we don't pollute downstream consumers |
| // with unwanted bits. Here we need to use the original result type's |
| // bitwidth. |
| unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth(); |
| Value mask = getScalarOrVectorConstInt( |
| dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc()); |
| rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType, |
| adaptor.getIn(), mask); |
| } else { |
| // Given this is truncation, either SConvertOp or UConvertOp works. |
| rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType, |
| adaptor.getOperands()); |
| } |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // TypeCastingOp |
| //===----------------------------------------------------------------------===// |
| |
| static std::optional<spirv::FPRoundingMode> |
| convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) { |
| switch (roundingMode) { |
| case arith::RoundingMode::downward: |
| return spirv::FPRoundingMode::RTN; |
| case arith::RoundingMode::to_nearest_even: |
| return spirv::FPRoundingMode::RTE; |
| case arith::RoundingMode::toward_zero: |
| return spirv::FPRoundingMode::RTZ; |
| case arith::RoundingMode::upward: |
| return spirv::FPRoundingMode::RTP; |
| case arith::RoundingMode::to_nearest_away: |
| // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode |
| // (as of SPIR-V 1.6) |
| return std::nullopt; |
| } |
| llvm_unreachable("Unhandled rounding mode"); |
| } |
| |
| /// Converts type-casting standard operations to SPIR-V operations. |
| template <typename Op, typename SPIRVOp> |
| struct TypeCastingOpPattern final : public OpConversionPattern<Op> { |
| using OpConversionPattern<Op>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType(); |
| Type dstType = this->getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) |
| return failure(); |
| |
| if (dstType == srcType) { |
| // Due to type conversion, we are seeing the same source and target type. |
| // Then we can just erase this operation by forwarding its operand. |
| rewriter.replaceOp(op, adaptor.getOperands().front()); |
| } else { |
| // Compute new rounding mode (if any). |
| std::optional<spirv::FPRoundingMode> rm = std::nullopt; |
| if (auto roundingModeOp = |
| dyn_cast<arith::ArithRoundingModeInterface>(*op)) { |
| if (arith::RoundingModeAttr roundingMode = |
| roundingModeOp.getRoundingModeAttr()) { |
| if (!(rm = |
| convertArithRoundingModeToSPIRV(roundingMode.getValue()))) { |
| return rewriter.notifyMatchFailure( |
| op->getLoc(), |
| llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); |
| } |
| } |
| } |
| // Create replacement op and attach rounding mode attribute (if any). |
| auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>( |
| op, dstType, adaptor.getOperands()); |
| if (rm) { |
| newOp->setAttr( |
| getDecorationString(spirv::Decoration::FPRoundingMode), |
| spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); |
| } |
| } |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // CmpIOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts integer compare operation on i1 type operands to SPIR-V ops. |
| class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = op.getLhs().getType(); |
| if (!isBoolScalarOrVector(srcType)) |
| return failure(); |
| Type dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op, srcType); |
| |
| switch (op.getPredicate()) { |
| case arith::CmpIPredicate::eq: { |
| rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| case arith::CmpIPredicate::ne: { |
| rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>( |
| op, adaptor.getLhs(), adaptor.getRhs()); |
| return success(); |
| } |
| case arith::CmpIPredicate::uge: |
| case arith::CmpIPredicate::ugt: |
| case arith::CmpIPredicate::ule: |
| case arith::CmpIPredicate::ult: { |
| // There are no direct corresponding instructions in SPIR-V for such |
| // cases. Extend them to 32-bit and do comparision then. |
| Type type = rewriter.getI32Type(); |
| if (auto vectorType = dyn_cast<VectorType>(dstType)) |
| type = VectorType::get(vectorType.getShape(), type); |
| Value extLhs = |
| arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); |
| Value extRhs = |
| arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); |
| |
| rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs, |
| extRhs); |
| return success(); |
| } |
| default: |
| break; |
| } |
| return failure(); |
| } |
| }; |
| |
| /// Converts integer compare operation to SPIR-V ops. |
| class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type srcType = op.getLhs().getType(); |
| if (isBoolScalarOrVector(srcType)) |
| return failure(); |
| Type dstType = getTypeConverter()->convertType(srcType); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op, srcType); |
| |
| switch (op.getPredicate()) { |
| #define DISPATCH(cmpPredicate, spirvOp) \ |
| case cmpPredicate: \ |
| if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \ |
| !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \ |
| !hasSameBitwidth(srcType, dstType)) { \ |
| return op.emitError( \ |
| "bitwidth emulation is not implemented yet on unsigned op"); \ |
| } \ |
| rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ |
| adaptor.getRhs()); \ |
| return success(); |
| |
| DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); |
| DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp); |
| DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp); |
| DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp); |
| DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp); |
| DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp); |
| DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp); |
| DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp); |
| |
| #undef DISPATCH |
| } |
| return failure(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // CmpFOpPattern |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts floating-point comparison operations to SPIR-V ops. |
| class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| switch (op.getPredicate()) { |
| #define DISPATCH(cmpPredicate, spirvOp) \ |
| case cmpPredicate: \ |
| rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \ |
| adaptor.getRhs()); \ |
| return success(); |
| |
| // Ordered. |
| DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp); |
| DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); |
| DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp); |
| DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp); |
| // Unordered. |
| DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp); |
| DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); |
| DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp); |
| DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); |
| DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp); |
| |
| #undef DISPATCH |
| |
| default: |
| break; |
| } |
| return failure(); |
| } |
| }; |
| |
| /// Converts floating point NaN check to SPIR-V ops. This pattern requires |
| /// Kernel capability. |
| class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (op.getPredicate() == arith::CmpFPredicate::ORD) { |
| rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| |
| if (op.getPredicate() == arith::CmpFPredicate::UNO) { |
| rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| /// Converts floating point NaN check to SPIR-V ops. This pattern does not |
| /// require additional capability. |
| class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> { |
| public: |
| using OpConversionPattern<arith::CmpFOp>::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (op.getPredicate() != arith::CmpFPredicate::ORD && |
| op.getPredicate() != arith::CmpFPredicate::UNO) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| |
| Value replace; |
| if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
| if (op.getPredicate() == arith::CmpFPredicate::ORD) { |
| // Ordered comparsion checks if neither operand is NaN. |
| replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter); |
| } else { |
| // Unordered comparsion checks if either operand is NaN. |
| replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); |
| } |
| } else { |
| Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); |
| Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); |
| |
| replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); |
| if (op.getPredicate() == arith::CmpFPredicate::ORD) |
| replace = spirv::LogicalNotOp::create(rewriter, loc, replace); |
| } |
| |
| rewriter.replaceOp(op, replace); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // AddUIExtendedOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.addui_extended to spirv.IAddCarry. |
| class AddUIExtendedOpPattern final |
| : public OpConversionPattern<arith::AddUIExtendedOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Type dstElemTy = adaptor.getLhs().getType(); |
| Location loc = op->getLoc(); |
| Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), |
| adaptor.getRhs()); |
| |
| Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, |
| llvm::ArrayRef(0)); |
| Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, |
| llvm::ArrayRef(1)); |
| |
| // Convert the carry value to boolean. |
| Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); |
| Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); |
| |
| rewriter.replaceOp(op, {sumResult, carryResult}); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // MulIExtendedOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.mul*i_extended to spirv.*MulExtended. |
| template <typename ArithMulOp, typename SPIRVMulOp> |
| class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> { |
| public: |
| using OpConversionPattern<ArithMulOp>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op->getLoc(); |
| Value result = |
| SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); |
| |
| Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, |
| llvm::ArrayRef(0)); |
| Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, |
| llvm::ArrayRef(1)); |
| |
| rewriter.replaceOp(op, {low, high}); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // SelectOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.select to spirv.Select. |
| class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(), |
| adaptor.getTrueValue(), |
| adaptor.getFalseValue()); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // MinimumFOp, MaximumFOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or |
| /// spirv.CL.fmax/fmin. |
| template <typename Op, typename SPIRVOp> |
| class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> { |
| public: |
| using OpConversionPattern<Op>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
| Type dstType = converter->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| // arith.maximumf/minimumf: |
| // "if one of the arguments is NaN, then the result is also NaN." |
| // spirv.GL.FMax/FMin |
| // "which operand is the result is undefined if one of the operands |
| // is a NaN." |
| // spirv.CL.fmax/fmin: |
| // "If one argument is a NaN, Fmin returns the other argument." |
| |
| Location loc = op.getLoc(); |
| Value spirvOp = |
| SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); |
| |
| if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
| rewriter.replaceOp(op, spirvOp); |
| return success(); |
| } |
| |
| Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); |
| Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); |
| |
| Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, |
| adaptor.getLhs(), spirvOp); |
| Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, |
| adaptor.getRhs(), select1); |
| |
| rewriter.replaceOp(op, select2); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // MinNumFOp, MaxNumFOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or |
| /// spirv.CL.fmax/fmin. |
| template <typename Op, typename SPIRVOp> |
| class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> { |
| template <typename TargetOp> |
| constexpr bool shouldInsertNanGuards() const { |
| return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value; |
| } |
| |
| public: |
| using OpConversionPattern<Op>::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto *converter = this->template getTypeConverter<SPIRVTypeConverter>(); |
| Type dstType = converter->convertType(op.getType()); |
| if (!dstType) |
| return getTypeConversionFailure(rewriter, op); |
| |
| // arith.maxnumf/minnumf: |
| // "If one of the arguments is NaN, then the result is the other |
| // argument." |
| // spirv.GL.FMax/FMin |
| // "which operand is the result is undefined if one of the operands |
| // is a NaN." |
| // spirv.CL.fmax/fmin: |
| // "If one argument is a NaN, Fmin returns the other argument." |
| |
| Location loc = op.getLoc(); |
| Value spirvOp = |
| SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); |
| |
| if (!shouldInsertNanGuards<SPIRVOp>() || |
| bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { |
| rewriter.replaceOp(op, spirvOp); |
| return success(); |
| } |
| |
| Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); |
| Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); |
| |
| Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, |
| adaptor.getRhs(), spirvOp); |
| Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, |
| adaptor.getLhs(), select1); |
| |
| rewriter.replaceOp(op, select2); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern Population |
| //===----------------------------------------------------------------------===// |
| |
| void mlir::arith::populateArithToSPIRVPatterns( |
| const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { |
| // clang-format off |
| patterns.add< |
| ConstantCompositeOpPattern, |
| ConstantScalarOpPattern, |
| ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>, |
| ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>, |
| ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>, |
| spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>, |
| spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>, |
| spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>, |
| RemSIOpGLPattern, RemSIOpCLPattern, |
| BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>, |
| BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>, |
| XOrIOpLogicalPattern, XOrIOpBooleanPattern, |
| ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>, |
| spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>, |
| spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>, |
| spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>, |
| spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>, |
| spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>, |
| spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>, |
| spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>, |
| spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>, |
| ExtUIPattern, ExtUII1Pattern, |
| ExtSIPattern, ExtSII1Pattern, |
| TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>, |
| TruncIPattern, TruncII1Pattern, |
| TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>, |
| TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern, |
| TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>, |
| TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>, |
| TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>, |
| TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>, |
| IndexCastIndexI1Pattern, IndexCastI1IndexPattern, |
| TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>, |
| TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>, |
| CmpIOpBooleanPattern, CmpIOpPattern, |
| CmpFOpNanNonePattern, CmpFOpPattern, |
| AddUIExtendedOpPattern, |
| MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>, |
| MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>, |
| SelectOpPattern, |
| |
| MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>, |
| MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>, |
| MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>, |
| MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>, |
| spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>, |
| spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>, |
| spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>, |
| spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>, |
| |
| MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>, |
| MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>, |
| MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>, |
| MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>, |
| spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>, |
| spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>, |
| spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>, |
| spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::CLUMinOp> |
| >(typeConverter, patterns.getContext()); |
| // clang-format on |
| |
| // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel |
| // capability is available. |
| patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(), |
| /*benefit=*/2); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass Definition |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| struct ConvertArithToSPIRVPass |
| : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> { |
| using Base::Base; |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); |
| std::unique_ptr<SPIRVConversionTarget> target = |
| SPIRVConversionTarget::get(targetAttr); |
| |
| SPIRVConversionOptions options; |
| options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; |
| options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; |
| SPIRVTypeConverter typeConverter(targetAttr, options); |
| |
| // Use UnrealizedConversionCast as the bridge so that we don't need to pull |
| // in patterns for other dialects. |
| target->addLegalOp<UnrealizedConversionCastOp>(); |
| |
| // Fail hard when there are any remaining 'arith' ops. |
| target->addIllegalDialect<arith::ArithDialect>(); |
| |
| RewritePatternSet patterns(&getContext()); |
| arith::populateArithToSPIRVPatterns(typeConverter, patterns); |
| |
| if (failed(applyPartialConversion(op, *target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |