| //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements patterns to convert Math dialect to SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "../SPIRVCommon/Pattern.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
| #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| #define DEBUG_TYPE "math-to-spirv-pattern" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions |
| //===----------------------------------------------------------------------===// |
| |
| /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the |
| /// given type is not a 32-bit scalar/vector type. |
| static Value getScalarOrVectorI32Constant(Type type, int value, |
| OpBuilder &builder, Location loc) { |
| if (auto vectorType = dyn_cast<VectorType>(type)) { |
| if (!vectorType.getElementType().isInteger(32)) |
| return nullptr; |
| SmallVector<int> values(vectorType.getNumElements(), value); |
| return spirv::ConstantOp::create(builder, loc, type, |
| builder.getI32VectorAttr(values)); |
| } |
| if (type.isInteger(32)) |
| return spirv::ConstantOp::create(builder, loc, type, |
| builder.getI32IntegerAttr(value)); |
| |
| return nullptr; |
| } |
| |
| /// Check if the type is supported by math-to-spirv conversion. We expect to |
| /// only see scalars and vectors at this point, with higher-level types already |
| /// lowered. |
| static bool isSupportedSourceType(Type originalType) { |
| if (originalType.isIntOrIndexOrFloat()) |
| return true; |
| |
| if (auto vecTy = dyn_cast<VectorType>(originalType)) { |
| if (!vecTy.getElementType().isIntOrIndexOrFloat()) |
| return false; |
| if (vecTy.isScalable()) |
| return false; |
| if (vecTy.getRank() > 1) |
| return false; |
| |
| return true; |
| } |
| |
| return false; |
| } |
| |
| /// Check if all `sourceOp` types are supported by math-to-spirv conversion. |
| /// Notify of a match failure othwerise and return a `failure` result. |
| /// This is intended to simplify type checks in `OpConversionPattern`s. |
| static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, |
| Operation *sourceOp) { |
| auto allTypes = llvm::to_vector(sourceOp->getOperandTypes()); |
| llvm::append_range(allTypes, sourceOp->getResultTypes()); |
| |
| for (Type ty : allTypes) { |
| if (!isSupportedSourceType(ty)) { |
| return rewriter.notifyMatchFailure( |
| sourceOp, |
| llvm::formatv( |
| "unsupported source type for Math to SPIR-V conversion: {0}", |
| ty)); |
| } |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operation conversion |
| //===----------------------------------------------------------------------===// |
| |
| // Note that DRR cannot be used for the patterns in this file: we may need to |
| // convert type along the way, which requires ConversionPattern. DRR generates |
| // normal RewritePattern. |
| |
| namespace { |
| /// Converts elementwise unary, binary, and ternary standard operations to |
| /// SPIR-V operations. Checks that source `Op` types are supported. |
| template <typename Op, typename SPIRVOp> |
| struct CheckedElementwiseOpPattern final |
| : public spirv::ElementwiseOpPattern<Op, SPIRVOp> { |
| using BasePattern = typename spirv::ElementwiseOpPattern<Op, SPIRVOp>; |
| using BasePattern::BasePattern; |
| |
| LogicalResult |
| matchAndRewrite(Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) |
| return res; |
| |
| return BasePattern::matchAndRewrite(op, adaptor, rewriter); |
| } |
| }; |
| |
| /// Converts math.copysign to SPIR-V ops. |
| struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, copySignOp); |
| failed(res)) |
| return res; |
| |
| Type type = getTypeConverter()->convertType(copySignOp.getType()); |
| if (!type) |
| return failure(); |
| |
| FloatType floatType; |
| if (auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) { |
| floatType = scalarType; |
| } else if (auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) { |
| floatType = cast<FloatType>(vectorType.getElementType()); |
| } else { |
| return failure(); |
| } |
| |
| Location loc = copySignOp.getLoc(); |
| int bitwidth = floatType.getWidth(); |
| Type intType = rewriter.getIntegerType(bitwidth); |
| uint64_t intValue = uint64_t(1) << (bitwidth - 1); |
| |
| Value signMask = spirv::ConstantOp::create( |
| rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue)); |
| Value valueMask = spirv::ConstantOp::create( |
| rewriter, loc, intType, |
| rewriter.getIntegerAttr(intType, intValue - 1u)); |
| |
| if (auto vectorType = dyn_cast<VectorType>(type)) { |
| assert(vectorType.getRank() == 1); |
| int count = vectorType.getNumElements(); |
| intType = VectorType::get(count, intType); |
| |
| Repeated<Value> signSplat(count, signMask); |
| signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, |
| signSplat); |
| |
| Repeated<Value> valueSplat(count, valueMask); |
| valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType, |
| valueSplat); |
| } |
| |
| Value lhsCast = |
| spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs()); |
| Value rhsCast = |
| spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs()); |
| |
| Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType, |
| ValueRange{lhsCast, valueMask}); |
| Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType, |
| ValueRange{rhsCast, signMask}); |
| |
| Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType, |
| ValueRange{value, sign}); |
| rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.ctlz to SPIR-V ops. |
| /// |
| /// OpenCL targets lower math.ctlz directly to OpenCL.std clz via the generic |
| /// elementwise pattern. This pattern handles the shader fallback. |
| /// |
| /// SPIR-V does not have a direct operations for counting leading zeros for |
| /// glsl. If Shader capability is supported, we can leverage GL FindUMsb to |
| /// calculate it. |
| struct CountLeadingZerosPattern final |
| : public OpConversionPattern<math::CountLeadingZerosOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res)) |
| return res; |
| |
| Type type = getTypeConverter()->convertType(countOp.getType()); |
| if (!type) |
| return failure(); |
| |
| auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>(); |
| if (!typeConverter.getTargetEnv().allows(spirv::Capability::Shader)) |
| return rewriter.notifyMatchFailure(countOp, "requires Shader capability"); |
| |
| // The GL FindUMsb fallback only supports 32-bit integer types for now. |
| unsigned bitwidth = 0; |
| if (isa<IntegerType>(type)) |
| bitwidth = type.getIntOrFloatBitWidth(); |
| if (auto vectorType = dyn_cast<VectorType>(type)) |
| bitwidth = vectorType.getElementTypeBitWidth(); |
| if (bitwidth != 32) |
| return failure(); |
| |
| Location loc = countOp.getLoc(); |
| Value input = adaptor.getOperand(); |
| Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); |
| Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); |
| Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); |
| |
| Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input); |
| // We need to subtract from 31 given that the index returned by GLSL |
| // FindUMsb is counted from the least significant bit. Theoretically this |
| // also gives the correct result even if the integer has all zero bits, in |
| // which case GL FindUMsb would return -1. |
| Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb); |
| // However, certain Vulkan implementations have driver bugs for the corner |
| // case where the input is zero. And.. it can be smart to optimize a select |
| // only involving the corner case. So separately compute the result when the |
| // input is either zero or one. |
| Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input); |
| Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, |
| subMsb); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.cttz to GL FindILsb. GL FindILsb returns -1 for a zero |
| /// input while math.cttz must return the bitwidth, so the zero case is |
| /// patched up with a select. |
| struct CountTrailingZerosPattern final |
| : public OpConversionPattern<math::CountTrailingZerosOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::CountTrailingZerosOp countOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, countOp); failed(res)) |
| return res; |
| |
| Type type = getTypeConverter()->convertType(countOp.getType()); |
| if (!type) |
| return failure(); |
| |
| unsigned bitwidth = 0; |
| if (isa<IntegerType>(type)) |
| bitwidth = type.getIntOrFloatBitWidth(); |
| else if (auto vectorType = dyn_cast<VectorType>(type)) |
| bitwidth = vectorType.getElementTypeBitWidth(); |
| if (bitwidth != 32) |
| return failure(); |
| |
| Location loc = countOp.getLoc(); |
| Value input = adaptor.getOperand(); |
| Value val0 = getScalarOrVectorI32Constant(type, 0, rewriter, loc); |
| Value valBitwidth = |
| getScalarOrVectorI32Constant(type, bitwidth, rewriter, loc); |
| |
| Value lsb = spirv::GLFindILsbOp::create(rewriter, loc, input); |
| Value isZero = spirv::IEqualOp::create(rewriter, loc, input, val0); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, isZero, valBitwidth, |
| lsb); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.expm1 to SPIR-V ops. |
| /// |
| /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to |
| /// these operations. |
| template <typename ExpOp> |
| struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 1); |
| if (LogicalResult res = checkSourceOpTypes(rewriter, operation); |
| failed(res)) |
| return res; |
| |
| Location loc = operation.getLoc(); |
| Type type = this->getTypeConverter()->convertType(operation.getType()); |
| if (!type) |
| return failure(); |
| |
| Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand()); |
| auto one = spirv::ConstantOp::getOne(type, loc, rewriter); |
| rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.log1p to SPIR-V ops. |
| /// |
| /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to |
| /// these operations. |
| template <typename LogOp> |
| struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 1); |
| if (LogicalResult res = checkSourceOpTypes(rewriter, operation); |
| failed(res)) |
| return res; |
| |
| Location loc = operation.getLoc(); |
| Type type = this->getTypeConverter()->convertType(operation.getType()); |
| if (!type) |
| return failure(); |
| |
| auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); |
| Value onePlus = |
| spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand()); |
| rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.log10 to GLSL SPIR-V ops. |
| /// |
| /// GLSL.std.450 has no Log10 instruction. Lower it as: |
| /// log10(x) = log(x) * 1/log(10) |
| struct Log10OpPattern final : public OpConversionPattern<math::Log10Op> { |
| using Base::Base; |
| |
| static constexpr double log10Reciprocal = |
| 0.4342944819032518276511289189166050822943970058036665661144537832; |
| |
| LogicalResult |
| matchAndRewrite(math::Log10Op operation, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| assert(adaptor.getOperands().size() == 1); |
| if (LogicalResult res = checkSourceOpTypes(rewriter, operation); |
| failed(res)) |
| return res; |
| |
| Location loc = operation.getLoc(); |
| Type type = this->getTypeConverter()->convertType(operation.getType()); |
| if (!type) |
| return rewriter.notifyMatchFailure(operation, "type conversion failed"); |
| |
| auto getConstantValue = [&](double value) { |
| if (auto floatType = dyn_cast<FloatType>(type)) { |
| return spirv::ConstantOp::create( |
| rewriter, loc, type, rewriter.getFloatAttr(floatType, value)); |
| } |
| if (auto vectorType = dyn_cast<VectorType>(type)) { |
| Type elemType = vectorType.getElementType(); |
| |
| if (isa<FloatType>(elemType)) { |
| return spirv::ConstantOp::create( |
| rewriter, loc, type, |
| DenseFPElementsAttr::get( |
| vectorType, FloatAttr::get(elemType, value).getValue())); |
| } |
| } |
| llvm_unreachable("unimplemented type for log10"); |
| }; |
| |
| Value constantValue = getConstantValue(log10Reciprocal); |
| Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getOperand()); |
| rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log, |
| constantValue); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.powf to SPIRV-Ops. |
| struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, powfOp); failed(res)) |
| return res; |
| |
| Type dstType = getTypeConverter()->convertType(powfOp.getType()); |
| if (!dstType) |
| return failure(); |
| |
| Location loc = powfOp.getLoc(); |
| Type operandType = adaptor.getRhs().getType(); |
| |
| // Parity-based lowering requires an integer-valued constant exponent. |
| // Otherwise fall back to exp(y*log(x)), which yields NaN for x<0 (matches |
| // C). |
| auto isOdd = [](const APFloat &v) { |
| APSInt i(/*BitWidth=*/64, /*isUnsigned=*/false); |
| bool ignored; |
| v.convertToInteger(i, APFloat::rmTowardZero, &ignored); |
| return i[0]; |
| }; |
| |
| SmallVector<bool> oddMask; |
| Attribute rhsAttr; |
| if (matchPattern(adaptor.getRhs(), m_Constant(&rhsAttr))) { |
| TypeSwitch<Attribute>(rhsAttr) |
| .Case([&](FloatAttr a) { |
| if (a.getValue().isInteger()) |
| oddMask.push_back(isOdd(a.getValue())); |
| }) |
| .Case([&](SplatElementsAttr a) { |
| APFloat splat = a.getSplatValue<APFloat>(); |
| if (splat.isInteger()) |
| oddMask.push_back(isOdd(splat)); |
| }) |
| .Case([&](DenseElementsAttr a) { |
| SmallVector<bool> mask; |
| for (const APFloat &elt : a.getValues<APFloat>()) { |
| if (!elt.isInteger()) |
| return; |
| mask.push_back(isOdd(elt)); |
| } |
| oddMask = std::move(mask); |
| }); |
| } |
| |
| if (oddMask.empty()) { |
| Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs()); |
| Value mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log); |
| rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp, mul); |
| return success(); |
| } |
| |
| // GL.Pow is undefined for x < 0; take abs and conditionally negate the |
| // result for lanes whose exponent is odd. |
| Value abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs()); |
| Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs()); |
| |
| // No odd-parity element: result has the same sign as |lhs|^rhs >= 0. |
| if (llvm::none_of(oddMask, [](bool b) { return b; })) { |
| rewriter.replaceOp(powfOp, pow); |
| return success(); |
| } |
| |
| Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter); |
| Value lessThan = |
| spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero); |
| Value negate = spirv::FNegateOp::create(rewriter, loc, pow); |
| |
| Value shouldNegate; |
| if (llvm::all_equal(oddMask)) { |
| // Every lane has odd exponent: negate iff lhs < 0. |
| shouldNegate = lessThan; |
| } else { |
| // Mixed parity (non-splat dense vector): AND lhs<0 with a per-element |
| // constant odd-mask. |
| auto vecType = cast<VectorType>(operandType); |
| auto maskType = VectorType::get(vecType.getShape(), rewriter.getI1Type()); |
| Value oddConst = spirv::ConstantOp::create( |
| rewriter, loc, maskType, DenseElementsAttr::get(maskType, oddMask)); |
| shouldNegate = |
| spirv::LogicalAndOp::create(rewriter, loc, lessThan, oddConst); |
| } |
| |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate, |
| pow); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.fpowi to spirv.CL.pown. |
| struct PowIOpPattern final : public OpConversionPattern<math::FPowIOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) |
| return res; |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return failure(); |
| |
| rewriter.replaceOpWithNewOp<spirv::CLPownOp>(op, dstType, adaptor.getLhs(), |
| adaptor.getRhs()); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.fpowi to GLSL SPIR-V ops. GL has no integer-power op, so the |
| /// exponent is converted to float and lowered through spirv.GL.Pow. As GL.Pow |
| /// is undefined for a negative base, the base is made positive and the result |
| /// is negated when the base is negative and the exponent is odd. |
| struct PowIOpGLPattern final : public OpConversionPattern<math::FPowIOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res)) |
| return res; |
| |
| Type dstType = getTypeConverter()->convertType(op.getType()); |
| if (!dstType) |
| return failure(); |
| |
| Location loc = op.getLoc(); |
| Value base = adaptor.getLhs(); |
| Value power = adaptor.getRhs(); |
| |
| Value expFloat = |
| spirv::ConvertSToFOp::create(rewriter, loc, dstType, power); |
| Value abs = spirv::GLFAbsOp::create(rewriter, loc, base); |
| Value pow = spirv::GLPowOp::create(rewriter, loc, abs, expFloat); |
| |
| Value zeroF = spirv::ConstantOp::getZero(dstType, loc, rewriter); |
| Value lessThan = spirv::FOrdLessThanOp::create(rewriter, loc, base, zeroF); |
| |
| Type powerType = power.getType(); |
| Value oneI = spirv::ConstantOp::getOne(powerType, loc, rewriter); |
| Value lowBit = spirv::BitwiseAndOp::create(rewriter, loc, power, oneI); |
| Value isOdd = spirv::IEqualOp::create(rewriter, loc, lowBit, oneI); |
| |
| Value shouldNegate = |
| spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd); |
| Value negate = spirv::FNegateOp::create(rewriter, loc, pow); |
| rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, shouldNegate, negate, pow); |
| return success(); |
| } |
| }; |
| |
| /// Converts math.round to GLSL SPIRV extended ops. |
| struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { |
| using Base::Base; |
| |
| LogicalResult |
| matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| if (LogicalResult res = checkSourceOpTypes(rewriter, roundOp); failed(res)) |
| return res; |
| |
| Location loc = roundOp.getLoc(); |
| auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType()); |
| if (!ty) { |
| return rewriter.notifyMatchFailure( |
| roundOp->getLoc(), |
| llvm::formatv("failed to convert type {0} for SPIR-V", |
| roundOp.getType())); |
| } |
| |
| Type ety = getElementTypeOrSelf(ty); |
| |
| auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); |
| auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); |
| Value half; |
| if (VectorType vty = dyn_cast<VectorType>(ty)) { |
| half = spirv::ConstantOp::create( |
| rewriter, loc, vty, |
| DenseElementsAttr::get(vty, |
| rewriter.getFloatAttr(ety, 0.5).getValue())); |
| } else { |
| half = spirv::ConstantOp::create(rewriter, loc, ty, |
| rewriter.getFloatAttr(ety, 0.5)); |
| } |
| |
| auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand()); |
| auto floor = spirv::GLFloorOp::create(rewriter, loc, abs); |
| auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor); |
| auto greater = |
| spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half); |
| auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero); |
| auto add = spirv::FAddOp::create(rewriter, loc, floor, select); |
| rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, |
| adaptor.getOperand()); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern population |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, |
| RewritePatternSet &patterns) { |
| // Core patterns |
| patterns |
| .add<CopySignPattern, |
| CheckedElementwiseOpPattern<math::CtPopOp, spirv::BitCountOp>, |
| CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>, |
| CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>, |
| CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>, |
| CheckedElementwiseOpPattern<math::IsNormalOp, spirv::IsNormalOp>>( |
| typeConverter, patterns.getContext()); |
| |
| // GLSL patterns |
| patterns |
| .add<CountLeadingZerosPattern, CountTrailingZerosPattern, |
| Log1pOpPattern<spirv::GLLogOp>, Log10OpPattern, |
| ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, PowIOpGLPattern, |
| RoundOpPattern, |
| CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>, |
| CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>, |
| CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>, |
| CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>, |
| CheckedElementwiseOpPattern<math::ClampFOp, spirv::GLFClampOp>, |
| CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>, |
| CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>, |
| CheckedElementwiseOpPattern<math::Exp2Op, spirv::GLExp2Op>, |
| CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>, |
| CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>, |
| CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>, |
| CheckedElementwiseOpPattern<math::Log2Op, spirv::GLLog2Op>, |
| CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>, |
| CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>, |
| CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>, |
| CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>, |
| CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>, |
| CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>, |
| CheckedElementwiseOpPattern<math::TruncOp, spirv::GLTruncOp>, |
| CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>, |
| CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>, |
| CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>, |
| CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>, |
| CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>, |
| CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>, |
| CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>( |
| typeConverter, patterns.getContext()); |
| |
| // OpenCL patterns |
| patterns.add< |
| Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>, |
| CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>, |
| CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>, |
| CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>, |
| CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>, |
| CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>, |
| CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>, |
| CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>, |
| CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>, |
| CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>, |
| CheckedElementwiseOpPattern<math::Exp2Op, spirv::CLExp2Op>, |
| CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>, |
| CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>, |
| CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>, |
| CheckedElementwiseOpPattern<math::Log2Op, spirv::CLLog2Op>, |
| CheckedElementwiseOpPattern<math::Log10Op, spirv::CLLog10Op>, |
| CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>, PowIOpPattern, |
| CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>, |
| CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>, |
| CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>, |
| CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>, |
| CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>, |
| CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>, |
| CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>, |
| CheckedElementwiseOpPattern<math::TruncOp, spirv::CLTruncOp>, |
| CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>, |
| CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>, |
| CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>, |
| CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>, |
| CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>, |
| CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>, |
| CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>( |
| typeConverter, patterns.getContext()); |
| } |
| |
| } // namespace mlir |