| //===- ScalarEvolutionDivision.cpp - See below ------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines the class that knows how to divide SCEVs. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Analysis/ScalarEvolutionDivision.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Analysis/ScalarEvolution.h" |
| #include "llvm/IR/Constants.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include <cassert> |
| #include <cstdint> |
| |
| namespace llvm { |
| class Type; |
| } |
| |
| using namespace llvm; |
| |
| namespace { |
| |
| static inline int sizeOfSCEV(const SCEV *S) { |
| struct FindSCEVSize { |
| int Size = 0; |
| |
| FindSCEVSize() = default; |
| |
| bool follow(const SCEV *S) { |
| ++Size; |
| // Keep looking at all operands of S. |
| return true; |
| } |
| |
| bool isDone() const { return false; } |
| }; |
| |
| FindSCEVSize F; |
| SCEVTraversal<FindSCEVSize> ST(F); |
| ST.visitAll(S); |
| return F.Size; |
| } |
| |
| } // namespace |
| |
| // Computes the Quotient and Remainder of the division of Numerator by |
| // Denominator. |
| void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator, |
| const SCEV *Denominator, const SCEV **Quotient, |
| const SCEV **Remainder) { |
| assert(Numerator && Denominator && "Uninitialized SCEV"); |
| |
| SCEVDivision D(SE, Numerator, Denominator); |
| |
| // Check for the trivial case here to avoid having to check for it in the |
| // rest of the code. |
| if (Numerator == Denominator) { |
| *Quotient = D.One; |
| *Remainder = D.Zero; |
| return; |
| } |
| |
| if (Numerator->isZero()) { |
| *Quotient = D.Zero; |
| *Remainder = D.Zero; |
| return; |
| } |
| |
| // A simple case when N/1. The quotient is N. |
| if (Denominator->isOne()) { |
| *Quotient = Numerator; |
| *Remainder = D.Zero; |
| return; |
| } |
| |
| // Split the Denominator when it is a product. |
| if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { |
| const SCEV *Q, *R; |
| *Quotient = Numerator; |
| for (const SCEV *Op : T->operands()) { |
| divide(SE, *Quotient, Op, &Q, &R); |
| *Quotient = Q; |
| |
| // Bail out when the Numerator is not divisible by one of the terms of |
| // the Denominator. |
| if (!R->isZero()) { |
| *Quotient = D.Zero; |
| *Remainder = Numerator; |
| return; |
| } |
| } |
| *Remainder = D.Zero; |
| return; |
| } |
| |
| D.visit(Numerator); |
| *Quotient = D.Quotient; |
| *Remainder = D.Remainder; |
| } |
| |
| void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { |
| if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { |
| APInt NumeratorVal = Numerator->getAPInt(); |
| APInt DenominatorVal = D->getAPInt(); |
| uint32_t NumeratorBW = NumeratorVal.getBitWidth(); |
| uint32_t DenominatorBW = DenominatorVal.getBitWidth(); |
| |
| if (NumeratorBW > DenominatorBW) |
| DenominatorVal = DenominatorVal.sext(NumeratorBW); |
| else if (NumeratorBW < DenominatorBW) |
| NumeratorVal = NumeratorVal.sext(DenominatorBW); |
| |
| APInt QuotientVal(NumeratorVal.getBitWidth(), 0); |
| APInt RemainderVal(NumeratorVal.getBitWidth(), 0); |
| APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); |
| Quotient = SE.getConstant(QuotientVal); |
| Remainder = SE.getConstant(RemainderVal); |
| return; |
| } |
| } |
| |
| void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { |
| const SCEV *StartQ, *StartR, *StepQ, *StepR; |
| if (!Numerator->isAffine()) |
| return cannotDivide(Numerator); |
| divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); |
| divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); |
| // Bail out if the types do not match. |
| Type *Ty = Denominator->getType(); |
| if (Ty != StartQ->getType() || Ty != StartR->getType() || |
| Ty != StepQ->getType() || Ty != StepR->getType()) |
| return cannotDivide(Numerator); |
| Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), |
| Numerator->getNoWrapFlags()); |
| Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), |
| Numerator->getNoWrapFlags()); |
| } |
| |
| void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) { |
| SmallVector<const SCEV *, 2> Qs, Rs; |
| Type *Ty = Denominator->getType(); |
| |
| for (const SCEV *Op : Numerator->operands()) { |
| const SCEV *Q, *R; |
| divide(SE, Op, Denominator, &Q, &R); |
| |
| // Bail out if types do not match. |
| if (Ty != Q->getType() || Ty != R->getType()) |
| return cannotDivide(Numerator); |
| |
| Qs.push_back(Q); |
| Rs.push_back(R); |
| } |
| |
| if (Qs.size() == 1) { |
| Quotient = Qs[0]; |
| Remainder = Rs[0]; |
| return; |
| } |
| |
| Quotient = SE.getAddExpr(Qs); |
| Remainder = SE.getAddExpr(Rs); |
| } |
| |
| void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) { |
| SmallVector<const SCEV *, 2> Qs; |
| Type *Ty = Denominator->getType(); |
| |
| bool FoundDenominatorTerm = false; |
| for (const SCEV *Op : Numerator->operands()) { |
| // Bail out if types do not match. |
| if (Ty != Op->getType()) |
| return cannotDivide(Numerator); |
| |
| if (FoundDenominatorTerm) { |
| Qs.push_back(Op); |
| continue; |
| } |
| |
| // Check whether Denominator divides one of the product operands. |
| const SCEV *Q, *R; |
| divide(SE, Op, Denominator, &Q, &R); |
| if (!R->isZero()) { |
| Qs.push_back(Op); |
| continue; |
| } |
| |
| // Bail out if types do not match. |
| if (Ty != Q->getType()) |
| return cannotDivide(Numerator); |
| |
| FoundDenominatorTerm = true; |
| Qs.push_back(Q); |
| } |
| |
| if (FoundDenominatorTerm) { |
| Remainder = Zero; |
| if (Qs.size() == 1) |
| Quotient = Qs[0]; |
| else |
| Quotient = SE.getMulExpr(Qs); |
| return; |
| } |
| |
| if (!isa<SCEVUnknown>(Denominator)) |
| return cannotDivide(Numerator); |
| |
| // The Remainder is obtained by replacing Denominator by 0 in Numerator. |
| ValueToSCEVMapTy RewriteMap; |
| RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero; |
| Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); |
| |
| if (Remainder->isZero()) { |
| // The Quotient is obtained by replacing Denominator by 1 in Numerator. |
| RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One; |
| Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap); |
| return; |
| } |
| |
| // Quotient is (Numerator - Remainder) divided by Denominator. |
| const SCEV *Q, *R; |
| const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); |
| // This SCEV does not seem to simplify: fail the division here. |
| if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) |
| return cannotDivide(Numerator); |
| divide(SE, Diff, Denominator, &Q, &R); |
| if (R != Zero) |
| return cannotDivide(Numerator); |
| Quotient = Q; |
| } |
| |
| SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, |
| const SCEV *Denominator) |
| : SE(S), Denominator(Denominator) { |
| Zero = SE.getZero(Denominator->getType()); |
| One = SE.getOne(Denominator->getType()); |
| |
| // We generally do not know how to divide Expr by Denominator. We initialize |
| // the division to a "cannot divide" state to simplify the rest of the code. |
| cannotDivide(Numerator); |
| } |
| |
| // Convenience function for giving up on the division. We set the quotient to |
| // be equal to zero and the remainder to be equal to the numerator. |
| void SCEVDivision::cannotDivide(const SCEV *Numerator) { |
| Quotient = Zero; |
| Remainder = Numerator; |
| } |