[DemandedBits] Improve accuracy of Add propagator
The current demand propagator for addition marks all input bits at and to the right of an alive output bit as alive. However, a carry won't propagate beyond a bit for which both operands are known to be zero (or one/zero in the case of subtraction), so a more accurate answer is possible given known bits.
I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.
This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.
Known A:     0_______0_______0_______0_______
Known B:     0_______0_______0_______0_______
AOut:        00000000001000000000000000000000
AB, current: 00000000001111111111111111111111
AB, patch:   00000000001111111000000000000000
Committed on behalf of: @rrika (Erika)
Differential Revision: https://reviews.llvm.org/D72423
diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
index aaee8c2..62e08f3 100644
--- a/llvm/lib/Analysis/DemandedBits.cpp
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -173,7 +173,21 @@
}
break;
case Instruction::Add:
+ if (AOut.isMask()) {
+ AB = AOut;
+ } else {
+ ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+ AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
+ }
+ break;
case Instruction::Sub:
+ if (AOut.isMask()) {
+ AB = AOut;
+ } else {
+ ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+ AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
+ }
+ break;
case Instruction::Mul:
// Find the highest live output bit. We don't need any more input
// bits than that (adds, and thus subtracts, ripple only to the
@@ -469,6 +483,86 @@
}
}
+static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
+ const APInt &AOut,
+ const KnownBits &LHS,
+ const KnownBits &RHS,
+ bool CarryZero, bool CarryOne) {
+ assert(!(CarryZero && CarryOne) &&
+ "Carry can't be zero and one at the same time");
+
+ // Returns the operand bits (operand OperandNo of LHS + RHS + carry-in) that
+ // are needed to produce the alive output bits AOut.
+ //
+ // The caller is expected to handle AOut.isMask() itself by using AOut
+ // directly: in that case LHS and RHS don't need to be computed at all.
+
+ // Boundary bits (both operands known 0, or both known 1): carry out is fixed.
+ APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
+
+ // First, derive the alive carry bits from the alive output bits: demand
+ // ripples right (via a bit-reversed add) up to the next set bit in Bound.
+ //   AOut         = -1----
+ //   Bound        = ----1-
+ //   ACarry&~AOut = --111-
+ APInt RBound = Bound.reverseBits();
+ APInt RAOut = AOut.reverseBits();
+ APInt RProp = RAOut + (RAOut | ~RBound);
+ APInt RACarry = RProp ^ ~RBound;
+ APInt ACarry = RACarry.reverseBits();
+
+ // Then, the alive input bits are determined from the alive carry bits:
+ APInt NeededToMaintainCarryZero;
+ APInt NeededToMaintainCarryOne;
+ if (OperandNo == 0) {
+ NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
+ NeededToMaintainCarryOne = LHS.One | ~RHS.One;
+ } else {
+ NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
+ NeededToMaintainCarryOne = RHS.One | ~LHS.One;
+ }
+
+ // As in computeForAddCarry: sums with all unknown bits set / all clear.
+ APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
+ APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
+
+ // The below is simplified from
+ //
+ // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
+ // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
+ // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
+ //
+ // APInt NeededToMaintainCarry =
+ // (CarryKnownZero & NeededToMaintainCarryZero) |
+ // (CarryKnownOne & NeededToMaintainCarryOne) |
+ // CarryUnknown;
+
+ APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
+ (PossibleSumOne | NeededToMaintainCarryOne);
+ // A bit is alive if its output bit is, or if an alive carry depends on it.
+ APInt AB = AOut | (ACarry & NeededToMaintainCarry);
+ return AB;
+}
+
+APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
+ const APInt &AOut,
+ const KnownBits &LHS,
+ const KnownBits &RHS) {
+ return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
+ false); // CarryZero=true, CarryOne=false: a plain add has no carry-in.
+}
+
+APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
+ const APInt &AOut,
+ const KnownBits &LHS,
+ const KnownBits &RHS) {
+ KnownBits NRHS; // RHS with its known-zero and known-one masks swapped.
+ NRHS.Zero = RHS.One;
+ NRHS.One = RHS.Zero;
+ return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
+ true); // CarryOne=true: presumably LHS - RHS modelled as LHS + ~RHS + 1.
+}
+
FunctionPass *llvm::createDemandedBitsWrapperPass() {
return new DemandedBitsWrapperPass();
}