blob: 6ebec7bdf1d1c4db6513c728493c93d92df3961e [file] [log] [blame]
"""Flexible enumeration of C types."""
from __future__ import division, print_function
from Enumeration import *
# TODO:
# - struct improvements (flexible arrays, packed &
# unpacked, alignment)
# - objective-c qualified id
# - anonymous / transparent unions
# - VLAs
# - block types
# - K&R functions
# - pass arguments of different types (test extension, transparent union)
# - varargs
###
# Actual type types
class Type(object):
    """Base class for the C types produced by the generators in this module.

    The defaults here report "not a bit-field".  Subclasses that are not
    spelled by a builtin name provide ``getTypedefDef(name, printer)`` so
    that ``getTypeName`` can introduce a fresh typedef for them.
    """

    def isBitField(self):
        """Return whether this type is used as a bit-field member (never, by default)."""
        return False

    def isPaddingBitField(self):
        """Return whether this type is a zero-width padding bit-field (never, by default)."""
        return False

    def getTypeName(self, printer):
        """Declare a typedef for this type through *printer* and return its name.

        The name is ``T<n>``, where ``n`` is the current length of
        ``printer.types``.  The declaration text comes from the subclass's
        ``getTypedefDef`` and is handed to ``printer.addDeclaration``.
        """
        typedef_name = "T%d" % len(printer.types)
        printer.addDeclaration(self.getTypedefDef(typedef_name, printer))
        return typedef_name
class BuiltinType(Type):
    """A builtin C type such as ``char`` or ``int``, referred to by its own name.

    Args:
        name: The C spelling of the type; also used as its type name.
        size: Size in bytes, reported by ``sizeof``.
        bitFieldSize: If not None, this type is used as a bit-field member of
            that width.  A width of 0 marks an unnamed padding bit-field.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True if a bit-field width was given."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True for a zero-width (padding) bit-field."""
        # Compare by value.  ``is 0`` depends on CPython's small-int caching
        # and triggers a SyntaxWarning on Python 3.8+.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit-field width; only valid when ``isBitField()``."""
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        # Builtins need no typedef; *printer* is unused.
        return self.name

    def sizeof(self):
        """Return the size in bytes given at construction."""
        return self.size

    def __str__(self):
        return self.name
class EnumType(Type):
    """A C enum whose enumerators are described by *enumerators*.

    Each entry of ``enumerators`` is an optional initializer string.  A falsy
    entry leaves that enumerator without an explicit ``= value``.  Every
    instance takes a distinct ``unique_id`` from a class-level counter.  That
    id is embedded in the enumerator names, so different instances produce
    different names.
    """

    # Next id to hand out; each constructor call bumps it.
    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        cls = self.__class__
        self.unique_id = cls.unique_id
        cls.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list for this enum."""
        pieces = []
        for position, init in enumerate(self.enumerators):
            piece = "enum%dval%d_%d" % (self.index, position, self.unique_id)
            if init:
                piece += " = %s" % (init)
            pieces.append(piece)
        return ", ".join(pieces)

    def __str__(self):
        return "enum { %s }" % (self.getEnumerators())

    def getTypedefDef(self, name, printer):
        """Return a ``typedef enum NAME { ... } NAME;`` declaration."""
        body = self.getEnumerators()
        return "typedef enum %s { %s } %s;" % (name, body, name)
class RecordType(Type):
    """A C struct or union built from a list of field types.

    Args:
        index: Position of this type in its generator's enumeration.
        isUnion: Truthy for a union, falsy for a struct.  It is used as an
            index into ``("struct", "union")``.
        fields: Field types in declaration order.  Bit-field types become
            ``: width`` members; padding bit-fields become unnamed ``: 0``.
    """

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        keyword = ("struct", "union")[self.isUnion]
        members = []
        for field in self.fields:
            if field.isBitField():
                members.append("%s : %d;" % (field, field.getBitFieldSize()))
            else:
                members.append("%s;" % field)
        return "%s { %s }" % (keyword, " ".join(members))

    def getTypedefDef(self, name, printer):
        """Return a named ``typedef struct|union NAME { ... } NAME;`` declaration.

        Field types are spelled through ``printer.getTypeName``.  Fields are
        named ``field<i>`` by position, except zero-width padding bit-fields,
        which stay unnamed.
        """
        members = []
        for position, field in enumerate(self.fields):
            spelled = printer.getTypeName(field)
            if not field.isBitField():
                members.append("%s field%d;" % (spelled, position))
            elif field.isPaddingBitField():
                members.append("%s : 0;" % (spelled,))
            else:
                members.append(
                    "%s field%d : %d;" % (spelled, position, field.getBitFieldSize())
                )
        # Name the struct for more readable LLVM IR.
        keyword = ("struct", "union")[self.isUnion]
        return "typedef %s %s { %s } %s;" % (keyword, name, " ".join(members), name)
class ArrayType(Type):
    """A C array, or a GCC-style vector when *isVector* is true.

    For arrays, *size* is the element count, or None for an incomplete
    array (``T name[]``).  For vectors, *size* is the total size in bytes.
    It must be positive and a multiple of the element type's ``sizeof()``.
    """

    def __init__(self, index, isVector, elementType, size):
        # Note that for vectors, this is the size in bytes.
        if isVector:
            assert size > 0
        else:
            assert size is None or size >= 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if not isVector:
            self.numElements = self.size
            return
        elementBytes = self.elementType.sizeof()
        assert not (self.size % elementBytes)
        self.numElements = self.size // elementBytes

    def __str__(self):
        if self.isVector:
            return "vector (%s)[%d]" % (self.elementType, self.size)
        if self.size is None:
            return "(%s)[]" % (self.elementType,)
        return "(%s)[%d]" % (self.elementType, self.size)

    def getTypedefDef(self, name, printer):
        """Return a typedef for the array, or a ``vector_size`` attribute typedef for vectors."""
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return "typedef %s %s __attribute__ ((vector_size (%d)));" % (
                elementName,
                name,
                self.size,
            )
        sizeStr = "" if self.size is None else str(self.size)
        return "typedef %s %s[%s];" % (elementName, name, sizeStr)
class ComplexType(Type):
    """A C ``_Complex`` type over *elementType*."""

    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        return "_Complex (%s)" % (self.elementType,)

    def getTypedefDef(self, name, printer):
        """Return ``typedef _Complex ELEM NAME;`` with ELEM spelled via *printer*."""
        elementName = printer.getTypeName(self.elementType)
        return "typedef _Complex %s %s;" % (elementName, name)
class FunctionType(Type):
    """A pointer-to-function type.

    Args:
        index: Position of this type in its generator's enumeration.
        returnType: Return type, or None for ``void``.
        argTypes: Argument types.  An empty sequence is spelled ``(void)``.
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    @staticmethod
    def _signatureParts(returnType, argTypes, spell):
        """Return (return spelling, argument-list spelling), using *spell* per type.

        The return type is spelled before the arguments.  This matters when
        *spell* declares typedefs as a side effect.
        """
        rt = "void" if returnType is None else spell(returnType)
        at = ", ".join(map(spell, argTypes)) if argTypes else "void"
        return rt, at

    def __str__(self):
        rt, at = self._signatureParts(self.returnType, self.argTypes, str)
        return "%s (*)(%s)" % (rt, at)

    def getTypedefDef(self, name, printer):
        """Return ``typedef RET (*NAME)(ARGS);`` with component types spelled via *printer*."""
        # Use printer.getTypeName, as RecordType/ArrayType/ComplexType do.
        # Non-builtin components then get their own typedef, instead of being
        # pasted in as their __str__ form (e.g. "struct { char; }" or
        # "(char)[4]"), which is not valid C.
        rt, at = self._signatureParts(
            self.returnType, self.argTypes, printer.getTypeName
        )
        return "typedef %s (*%s)(%s);" % (rt, name, at)
###
# Type enumerators
class TypeGenerator(object):
    """Base class for enumerators of C types.

    A generator maps each index ``0 <= N < cardinality`` to a type.
    Subclasses implement ``setCardinality`` (which must make
    ``self.cardinality`` available) and ``generateType``.  ``get`` memoizes
    the generated types.
    """

    def __init__(self):
        # Already-generated types, keyed by index.
        self.cache = {}

    def setCardinality(self):
        """Compute this generator's cardinality; subclasses must override."""
        raise NotImplementedError(
            "%s must implement setCardinality" % type(self).__name__
        )

    def get(self, N):
        """Return the N-th type, generating and caching it on first use.

        On a cache miss, asserts that ``0 <= N < self.cardinality``.
        """
        T = self.cache.get(N)
        if T is None:
            # Kept as an assert: callers (e.g. test()) catch AssertionError
            # to detect indices past the end of the enumeration.
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type; subclasses must override."""
        raise NotImplementedError(
            "%s must implement generateType" % type(self).__name__
        )
class FixedTypeGenerator(TypeGenerator):
    """Enumerate exactly the types in *types*, in list order."""

    def __init__(self, types):
        TypeGenerator.__init__(self)
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        # The enumeration is the list itself.
        count = len(self.types)
        self.cardinality = count

    def generateType(self, N):
        chosen = self.types[N]
        return chosen
# Factorial
def fact(n):
    """Return n! computed iteratively; any n <= 0 yields 1."""
    product = 1
    remaining = n
    while remaining > 0:
        product *= remaining
        remaining -= 1
    return product
# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the binomial coefficient C(n, k), computed from factorials.

    ``fact`` returns 1 for non-positive arguments, so for non-negative n a
    ``k > n`` gives 0.
    """
    numerator = fact(n)
    denominator = fact(k) * fact(n - k)
    return numerator // denominator
# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield each k-element combination of the sequence *values* as a list.

    Combinations come out in lexicographic order of element positions.  If
    k exceeds ``len(values)``, nothing is yielded.  Adapted from ActiveState
    Recipe 190465 (generators for permutations, combinations and selections
    of a sequence).
    """
    if k == 0:
        yield []
        return
    last_start = len(values) - k
    for start in range(last_start + 1):
        head = values[start]
        for tail in combinations(values[start + 1 :], k - 1):
            yield [head] + tail
class EnumTypeGenerator(TypeGenerator):
    """Enumerate enums whose enumerator initializers are combinations of *values*.

    Types are ordered first by enumerator count, from *minEnumerators* up to
    *maxEnumerators*.  Within one count they follow the order in which
    ``combinations`` yields them.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        total = len(self.values)
        self.cardinality = sum(
            num_combinations(total, count)
            for count in range(self.minEnumerators, self.maxEnumerators + 1)
        )

    def generateType(self, n):
        total = len(self.values)
        # Walk the per-count buckets to find how many enumerators the n-th
        # type has, and how many types precede that bucket.
        count = self.minEnumerators
        skipped = 0
        while count < self.maxEnumerators:
            bucket = num_combinations(total, count)
            if skipped + bucket > n:
                break
            count += 1
            skipped += bucket
        # Pick the requested combination within that bucket.
        target = n - skipped
        for position, enumerators in enumerate(combinations(self.values, count)):
            if position == target:
                return EnumType(n, enumerators)
        assert False
class ComplexTypeGenerator(TypeGenerator):
    """Enumerate ``_Complex T`` for each type T produced by *typeGen*."""

    def __init__(self, typeGen):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        # One complex type per underlying element type.
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
class VectorTypeGenerator(TypeGenerator):
    """Enumerate vector types for every (byte size, element type) pair.

    *sizes* are total vector sizes in bytes.  They are converted with
    ``int``.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.sizes = tuple(int(size) for size in sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, typeIndex = getNthPairBounded(
            N, len(self.sizes), self.typeGen.cardinality
        )
        elementType = self.typeGen.get(typeIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerate array types for every (size, element type) pair.

    Args:
        typeGen: Generator supplying the element types.
        sizes: Iterable of array element counts.  ``ArrayType`` also accepts
            None, meaning an incomplete array.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Previously ``tuple(size)``: an undefined name, so construction
        # always raised NameError.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes) * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # isVector must be the builtin False; the previous ``false`` was an
        # undefined name.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
class ArrayTypeGenerator(TypeGenerator):
    """Enumerate array types over the element types from *typeGen*.

    Sizes run from 1 to *maxSize*.  *useZero* adds a zero-length array and
    *useIncomplete* adds an incomplete array (size None).
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # Number of distinct sizes: the optional incomplete and zero-length
        # variants (booleans counted as 0/1) plus 1..maxSize.
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete and S == 0:
            # Slot 0 is reserved for the incomplete array.
            size = None
        else:
            if self.useIncomplete:
                S -= 1
            size = S if self.useZero else S + 1
        return ArrayType(N, False, self.typeGen.get(T), size)
class RecordTypeGenerator(TypeGenerator):
    """Enumerate struct (and optionally union) types over field types from *typeGen*.

    Args:
        typeGen: Generator supplying the field types.
        useUnion: If true, every field tuple is produced both as a struct and
            as a union.  The low bit of the index selects which.
        maxSize: Maximum number of fields, or ``aleph0`` for no limit.
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # int() always returns an exact int.  Converting the aleph0 sentinel
        # lost its identity, which left the ``is aleph0`` branch in
        # setCardinality unreachable.  Keep the sentinel as-is.
        if maxSize is aleph0:
            self.maxSize = maxSize
        else:
            self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            # card**i field tuples of each length i in 0..maxSize, times two
            # when both struct and union variants are produced.
            S = 0
            for i in range(self.maxSize + 1):
                S += M * (self.typeGen.cardinality**i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            isUnion, I = (I & 1), I >> 1
        fields = [
            self.typeGen.get(f)
            for f in getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        ]
        return RecordType(N, isUnion, fields)
class FunctionTypeGenerator(TypeGenerator):
    """Enumerate function types over argument (and return) types from *typeGen*.

    Args:
        typeGen: Generator supplying the argument and return types.
        useReturn: If true, each function's return type is drawn from
            *typeGen*.  Otherwise every generated function returns void.
        maxSize: Maximum number of arguments, or ``aleph0`` for no limit.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        card = self.typeGen.cardinality
        if self.maxSize is aleph0:
            # ``cardinality`` is a value (attribute or property), not a
            # method; the previous ``cardinality()`` call raised TypeError.
            S = aleph0 * card
        else:
            # With a return type, a signature is a (ret, args...) tuple of
            # length 1..maxSize+1.  Without one, it is the args alone, of
            # length 0..maxSize.
            if self.useReturn:
                lo, hi = 1, self.maxSize + 1
            else:
                lo, hi = 0, self.maxSize
            S = 0
            for i in range(lo, hi + 1):
                S += card**i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # Skip the empty tuple, which has no slot for the return type.
            indices = getNthTuple(N + 1, self.maxSize + 1, self.typeGen.cardinality)
            retIndex, argIndices = indices[0], indices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)
class AnyTypeGenerator(TypeGenerator):
    """Enumerate the concatenation of several generators' types.

    Generators are added with ``addGenerator``.  Sub-generator cardinalities
    may depend on this one's (presumably for recursive type definitions), so
    every addition recomputes all cardinalities until they stop changing.
    While ``_cardinality`` is None (before the first addition, and during
    recomputation), ``cardinality`` reads as ``aleph0``.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Unknown until a generator is added; getCardinality reports aleph0.
        self._cardinality = None

    def getCardinality(self):
        known = self._cardinality
        return aleph0 if known is None else known

    def setCardinality(self):
        self.bounds = [gen.cardinality for gen in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Append generator *g* and recompute cardinalities to a fixed point.

        Raises RuntimeError if they have not settled after 100 rounds.
        """
        self.generators.append(g)
        for _ in range(100):
            previous = self._cardinality
            # Report aleph0 while sub-generators recompute their counts.
            self._cardinality = None
            for member in self.generators:
                member.setCardinality()
            self.setCardinality()
            if self._cardinality is aleph0 or previous == self._cardinality:
                return
        raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        which, offset = getNthPairVariableBounds(N, self.bounds)
        return self.generators[which].get(offset)
def test():
    """Print the first types of a mixed enumeration as a smoke test.

    Also checks that indexing exactly one past the reported cardinality
    trips the range assertion in ``TypeGenerator.get``.
    """
    # Field types: a plain char, a zero-width padding bit-field, and a
    # 5-bit int bit-field.
    fbtg = FixedTypeGenerator(
        [BuiltinType("char", 4), BuiltinType("char", 4, 0), BuiltinType("int", 4, 5)]
    )
    fields1 = AnyTypeGenerator()
    fields1.addGenerator(fbtg)
    fields0 = AnyTypeGenerator()
    fields0.addGenerator(fbtg)
    # Nested records (disabled):
    # fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )
    btg = FixedTypeGenerator([BuiltinType("char", 4), BuiltinType("int", 4)])
    etg = EnumTypeGenerator([None, "-1", "1", "1u"], 0, 3)
    atg = AnyTypeGenerator()
    atg.addGenerator(btg)
    atg.addGenerator(RecordTypeGenerator(fields0, False, 4))
    atg.addGenerator(etg)
    print("Cardinality:", atg.cardinality)
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
            except AssertionError:
                break
            # get() accepted an out-of-range index.
            raise RuntimeError("Cardinality was wrong")
        print("%4d: %s" % (i, atg.get(i)))
if "__main__" == __name__:
    # Running the module directly prints the demo enumeration.
    test()