"""Utilities for enumeration of finite and countably infinite sets.
"""
###
# Countable iteration

# Simplifies some calculations
class Aleph0(int):
    """Singleton standing in for the size of a countably infinite set.

    The class derives from ``int`` so the object can be passed wherever a
    bound is expected, but arithmetic and ordering are overridden so that it
    acts like a value larger than every ordinary number:

    * addition returns aleph0, multiplication returns aleph0 unless the
      other operand equals 0 (that operand is returned instead);
    * subtraction in either direction raises ValueError;
    * division returns aleph0 and raises ZeroDivisionError when the other
      operand equals 0 (the reflected forms share this implementation);
    * raising to the power 0 gives 1, any other power gives aleph0;
    * it compares greater than everything except itself.
    """

    _singleton = None

    def __new__(cls):
        # Only one instance is ever created, so ``x is aleph0`` is a valid
        # test everywhere in this module.
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return '<aleph0>'

    def __str__(self):
        return 'inf'

    # Python 2 falls back to __cmp__; Python 3 ignores it entirely, so the
    # rich comparison methods below are what make the ordering work there
    # (without them the object would compare like the plain int 0).
    def __cmp__(self, b):
        return 1

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        return b is not self

    # Defining __eq__ would otherwise drop hashability on Python 3.
    __hash__ = int.__hash__

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")
    __rsub__ = __sub__

    def __add__(self, b):
        return self
    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self
    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    # Was misspelled ``__rtuediv__``, which left reflected true division
    # handled by plain int.
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__
    __rdiv__ = __floordiv__

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()

def base(line):
    """Return line*(line+1)//2, i.e. 0 + 1 + ... + line.

    pairToN and getNthPairInfo use this as the index of the first pair on
    diagonal number `line`.
    """
    successor = line + 1
    return line * successor // 2

def pairToN(pair):
    """Return the position of the pair (x, y) in the diagonal enumeration.

    The diagonal is x + y and the offset along it is y, so this is the
    inverse of getNthPair.
    """
    x, y = pair
    diagonal = x + y
    return base(diagonal) + y

def getNthPairInfo(N):
    """Return (line, index) for position N of the diagonal enumeration.

    `line` is the largest value with base(line) <= N and `index` is
    N - base(line), the offset of N along that diagonal.
    """
    # N == 0 is answered directly; the search below starts at line 1.
    if N == 0:
        return (0, 0)

    # Gallop: double the candidate until its base exceeds N, which leaves
    # base(lower) <= N < base(upper) with upper == 2 * lower.
    lower = 1
    upper = 2
    while base(upper) <= N:
        lower = upper
        upper = lower * 2

    # Bisect the bracket down to two adjacent lines.
    while upper - lower != 1:
        middle = (lower + upper) // 2
        if base(middle) <= N:
            lower = middle
        else:
            upper = middle

    return lower, N - base(lower)

def getNthPair(N):
    """Return the N-th pair (x, y) of the unbounded diagonal enumeration.

    With (line, index) taken from getNthPairInfo, y is the index along the
    diagonal and x is whatever remains of the line, so x + y == line.
    """
    diagonal, offset = getNthPairInfo(N)
    x = diagonal - offset
    return (x, offset)

def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    W and H are each either a positive number or aleph0 (no bound).  When
    both are aleph0 the result is simply getNthPair(N).  Otherwise, if
    useDivmod is true the pair is N modulo / divided by the smaller bound;
    if it is false the pairs are enumerated along diagonals swept across
    the W x H rectangle.

    Raises ValueError if a bound is not positive, or if N is negative or
    not smaller than W*H.
    """

    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N < 0 or N >= W*H:
        # Negative N used to slip through and yield pairs outside the
        # documented range; reject it like getNthPairVariableBounds does.
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W <= H: solve the transposed problem
    # and swap the coordinates back.
    if H < W:
        x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
        return y,x

    if useDivmod:
        return N%W,N//W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left, just return as usual
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, subtract from corner. M counts
        # backwards from the last position, so the pair is mirrored
        # through the far corner (W-1, H-1).
        if H is not aleph0:
            M = W*H - N - 1
            if M < cornerSize:
                x,y = getNthPair(M)
                return (W-1-x,H-1-y)

        # Otherwise, compile line and index from number of times we
        # wrap. Every diagonal in this middle band holds exactly W pairs.
        N = N - cornerSize
        index,offset = N%W,N//W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W-1-index, 1+offset+index)
def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    """Return GNP(N, W, H, useDivmod) after asserting the pair is in bounds.

    GNP defaults to the getNthPairBounded that was bound when this wrapper
    was defined.
    """
    col, row = GNP(N, W, H, useDivmod)
    assert (0 <= col < W) and (0 <= row < H)
    return col, row

def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where each element is decoded against the
    bound H.

    With useLeftToRight the elements are peeled off one at a time: each
    step splits N with getNthPairBounded(N, H) into the next element and
    the value that is decoded further.  Otherwise the tuple is split into
    a left half of W//2 elements and a right half with the rest, N is
    split into one index per half, and the halves are decoded recursively.
    """

    if useLeftToRight:
        items = []
        for _ in range(W):
            head, N = getNthPairBounded(N, H)
            items.append(head)
        return tuple(items)

    # Small widths are answered directly.
    if W == 0:
        return ()
    if W == 1:
        return (N,)
    if W == 2:
        return getNthPairBounded(N, H, H)

    leftWidth = W // 2
    rightWidth = W - leftWidth
    leftIndex, rightIndex = getNthPairBounded(N, H**leftWidth, H**rightWidth)
    leftPart = getNthNTuple(leftIndex, leftWidth, H=H,
                            useLeftToRight=useLeftToRight)
    rightPart = getNthNTuple(rightIndex, rightWidth, H=H,
                             useLeftToRight=useLeftToRight)
    return leftPart + rightPart
def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Return GNT(N, W, H, useLeftToRight) after asserting its shape.

    Asserts that the result has exactly W elements and that every element
    is below H.  GNT defaults to the getNthNTuple that was bound when this
    wrapper was defined.
    """
    result = GNT(N, W, H, useLeftToRight)
    assert len(result) == W
    assert all(element < H for element in result)
    return result

def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and for y in x, 0 <=
    y < maxElement.

    N == 0 is the empty tuple.  Every other N is split into a size and an
    index among the tuples of that size, which getNthNTuple then decodes.
    Raises NotImplementedError when maxElement is finite but maxSize is
    aleph0.
    """

    # There is only one empty tuple, and it takes position 0.
    if N == 0:
        return ()
    rank = N - 1

    if maxElement is aleph0:
        sizeIndex, member = getNthPairBounded(rank, maxSize, useDivmod=useDivmod)
    else:
        if maxSize is aleph0:
            raise NotImplementedError('Max element size without max size unhandled')
        # With a finite element bound there are maxElement**size tuples of
        # each size, so each size gets its own bound.
        counts = []
        for size in range(1, maxSize+1):
            counts.append(maxElement**size)
        sizeIndex, member = getNthPairVariableBounds(rank, counts)

    # sizeIndex is zero-based while tuple sizes start at 1.
    return getNthNTuple(member, sizeIndex+1, maxElement, useLeftToRight=useLeftToRight)
def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0, 
                       useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    """Return GNT(N, maxSize, maxElement, ...) after asserting its bounds.

    Asserts that the result has at most maxSize elements and that every
    element is below maxElement.  GNT defaults to the getNthTuple that was
    bound when this wrapper was defined.
    """
    # FIXME: maxsize is inclusive
    result = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(result) <= maxSize
    assert all(element < maxElement for element in result)
    return result

def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0),
    return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    bounds[x].

    The columns are visited in order of increasing bound.  Each step
    covers a rectangular band: its width is the number of columns not yet
    exhausted and its height is the gap between the current bound and the
    previous one.  N is reduced by the size of every band it lies beyond,
    and the band that contains it is enumerated with getNthPairBounded.

    Raises ValueError if bounds is empty or N is not in
    [0, sum(bounds)).
    """

    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    # sorted() instead of range(...).sort(): a Python 3 range object has
    # no sort() method, while sorted() behaves the same on both versions.
    active = sorted(range(len(bounds)), key=lambda i: bounds[i])
    prevLevel = 0
    for i,index in enumerate(active):
        level = bounds[index]
        # Columns active[i:] all reach at least this level.
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W*H
        if N<levelSize: # Found the level
            idelta,delta = getNthPairBounded(N, W, H)
            return active[i+idelta],prevLevel+delta
        N -= levelSize
        prevLevel = level

    # The range check above should make this unreachable.  The name was
    # previously misspelled, which would have raised NameError instead.
    raise RuntimeError("Unexpected loop completion")

def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    x,y = GNVP(N,bounds)
    assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
    return (x,y)

###

def testPairs():
    """Print the enumeration of a 3 x 6 rectangle in both orders.

    Every position is listed with the pair produced by the diagonal order
    and by the divmod order, followed by two grids ("a" for diagonal, "b"
    for divmod) showing which position landed on which cell.  Rows are
    printed with the largest y first and empty rows are skipped.
    """
    W = 3
    H = 6
    a = [['  ' for x in range(10)] for y in range(10)]
    b = [['  ' for x in range(10)] for y in range(10)]
    for i in range(min(W*H,40)):
        x,y = getNthPairBounded(i,W,H)
        x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
        # Single-argument print(...) yields the same text under the
        # Python 2 print statement and the Python 3 print function.
        print('%s %s %s' % (i, (x,y), (x2,y2)))
        a[y][x] = '%2d'%i
        b[y2][x2] = '%2d'%i

    for title, grid in (('-- a --', a), ('-- b --', b)):
        print(title)
        for ln in grid[::-1]:
            if ''.join(ln).strip():
                print('  '.join(ln))

def testPairsVB():
    """Print the first positions of a variable-bounds enumeration.

    Uses a mix of finite and aleph0 bounds, lists each of the first 40
    positions with its pair, then prints a grid showing which position
    landed on which cell.  Rows are printed with the largest y first and
    empty rows are skipped.
    """
    bounds = [2,2,4,aleph0,5,aleph0]
    a = [['  ' for x in range(15)] for y in range(15)]
    for i in range(min(sum(bounds),40)):
        x,y = getNthPairVariableBounds(i, bounds)
        # Single-argument print(...) yields the same text under the
        # Python 2 print statement and the Python 3 print function.
        print('%s %s' % (i, (x,y)))
        a[y][x] = '%2d'%i

    print('-- a --')
    for ln in a[::-1]:
        if ''.join(ln).strip():
            print('  '.join(ln))

###

# Toggle to use checked versions of enumeration routines.
# Change the condition to True to rebind the public names to their
# assert-checking wrappers. Each wrapper took the unchecked function as a
# default argument when it was defined, so after the rebinding it still
# delegates to the original implementation rather than to itself.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked

# When run as a script, print the demonstration tables for the bounded and
# the variable-bounds pair enumerations.
if __name__ == '__main__':
    testPairs()

    testPairsVB()

