from __future__ import print_function
import struct
import sys

import gdb.printing
import gdb.types


class Iterator:
    def __iter__(self):
        return self

    if sys.version_info.major == 2:

        def next(self):
            return self.__next__()

    def children(self):
        return self


class SmallStringPrinter:
    """Print an llvm::SmallString object."""

    def __init__(self, val):
        self.val = val

    def to_string(self):
        data = self.val["BeginX"].cast(gdb.lookup_type("char").pointer())
        length = self.val["Size"]
        return data.lazy_string(length=length)

    def display_hint(self):
        return "string"


class StringRefPrinter:
    """Print an llvm::StringRef object."""

    def __init__(self, val):
        self.val = val

    def to_string(self):
        data = self.val["Data"]
        length = self.val["Length"]
        return data.lazy_string(length=length)

    def display_hint(self):
        return "string"


class SmallVectorPrinter(Iterator):
    """Print an llvm::SmallVector object."""

    def __init__(self, val):
        self.val = val
        t = val.type.template_argument(0).pointer()
        self.begin = val["BeginX"].cast(t)
        self.size = val["Size"]
        self.i = 0

    def __next__(self):
        if self.i == self.size:
            raise StopIteration
        ret = "[{}]".format(self.i), (self.begin + self.i).dereference()
        self.i += 1
        return ret

    def to_string(self):
        return "llvm::SmallVector of Size {}, Capacity {}".format(
            self.size, self.val["Capacity"]
        )

    def display_hint(self):
        return "array"


class ArrayRefPrinter:
    """Print an llvm::ArrayRef object."""

    class _iterator:
        def __init__(self, begin, end):
            self.cur = begin
            self.end = end
            self.count = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self.cur == self.end:
                raise StopIteration
            count = self.count
            self.count = self.count + 1
            cur = self.cur
            self.cur = self.cur + 1
            return "[%d]" % count, cur.dereference()

        if sys.version_info.major == 2:
            next = __next__

    def __init__(self, val):
        self.val = val

    def children(self):
        data = self.val["Data"]
        return self._iterator(data, data + self.val["Length"])

    def to_string(self):
        return "llvm::ArrayRef of length %d" % (self.val["Length"])

    def display_hint(self):
        return "array"


class ExpectedPrinter(Iterator):
    """Print an llvm::Expected object."""

    def __init__(self, val):
        self.val = val

    def __next__(self):
        val = self.val
        if val is None:
            raise StopIteration
        self.val = None
        if val["HasError"]:
            return (
                "error",
                val["ErrorStorage"]
                .address.cast(gdb.lookup_type("llvm::ErrorInfoBase").pointer())
                .dereference(),
            )
        return (
            "value",
            val["TStorage"]
            .address.cast(val.type.template_argument(0).pointer())
            .dereference(),
        )

    def to_string(self):
        return "llvm::Expected{}".format(" is error" if self.val["HasError"] else "")


class OptionalPrinter(Iterator):
    """Print an llvm::Optional object."""

    def __init__(self, val):
        self.val = val

    def __next__(self):
        val = self.val
        if val is None:
            raise StopIteration
        self.val = None
        if not val["Storage"]["hasVal"]:
            raise StopIteration
        return ("value", val["Storage"]["val"])

    def to_string(self):
        return "llvm::Optional{}".format(
            "" if self.val["Storage"]["hasVal"] else " is not initialized"
        )


class DenseMapPrinter:
    "Print a DenseMap"

    class _iterator:
        def __init__(self, key_info_t, begin, end):
            self.key_info_t = key_info_t
            self.cur = begin
            self.end = end
            self.advancePastEmptyBuckets()
            self.first = True

        def __iter__(self):
            return self

        def advancePastEmptyBuckets(self):
            # disabled until the comments below can be addressed
            # keeping as notes/posterity/hints for future contributors
            return
            n = self.key_info_t.name
            is_equal = gdb.parse_and_eval(n + "::isEqual")
            empty = gdb.parse_and_eval(n + "::getEmptyKey()")
            tombstone = gdb.parse_and_eval(n + "::getTombstoneKey()")
            # the following is invalid, GDB fails with:
            #   Python Exception <class 'gdb.error'> Attempt to take address of value
            #   not located in memory.
            # because isEqual took parameter (for the unsigned long key I was testing)
            # by const ref, and GDB
            # It's also not entirely general - we should be accessing the "getFirst()"
            # member function, not the 'first' member variable, but I've yet to figure
            # out how to find/call member functions (especially (const) overloaded
            # ones) on a gdb.Value.
            while self.cur != self.end and (
                is_equal(self.cur.dereference()["first"], empty)
                or is_equal(self.cur.dereference()["first"], tombstone)
            ):
                self.cur = self.cur + 1

        def __next__(self):
            if self.cur == self.end:
                raise StopIteration
            cur = self.cur
            v = cur.dereference()["first" if self.first else "second"]
            if not self.first:
                self.cur = self.cur + 1
                self.advancePastEmptyBuckets()
                self.first = True
            else:
                self.first = False
            return "x", v

        if sys.version_info.major == 2:
            next = __next__

    def __init__(self, val):
        self.val = val

    def children(self):
        t = self.val.type.template_argument(3).pointer()
        begin = self.val["Buckets"].cast(t)
        end = (begin + self.val["NumBuckets"]).cast(t)
        return self._iterator(self.val.type.template_argument(2), begin, end)

    def to_string(self):
        return "llvm::DenseMap with %d elements" % (self.val["NumEntries"])

    def display_hint(self):
        return "map"


class StringMapPrinter:
    "Print a StringMap"

    def __init__(self, val):
        self.val = val

    def children(self):
        it = self.val["TheTable"]
        end = it + self.val["NumBuckets"]
        value_ty = self.val.type.template_argument(0)
        entry_base_ty = gdb.lookup_type("llvm::StringMapEntryBase")
        tombstone = gdb.parse_and_eval("llvm::StringMapImpl::TombstoneIntVal")

        while it != end:
            it_deref = it.dereference()
            if it_deref == 0 or it_deref == tombstone:
                it = it + 1
                continue

            entry_ptr = it_deref.cast(entry_base_ty.pointer())
            entry = entry_ptr.dereference()

            str_len = entry["keyLength"]
            value_ptr = (entry_ptr + 1).cast(value_ty.pointer())
            str_data = (entry_ptr + 1).cast(gdb.lookup_type("uintptr_t")) + max(
                value_ty.sizeof, entry_base_ty.alignof
            )
            str_data = str_data.cast(gdb.lookup_type("char").const().pointer())
            string_ref = gdb.Value(
                struct.pack("PN", int(str_data), int(str_len)),
                gdb.lookup_type("llvm::StringRef"),
            )
            yield "key", string_ref

            value = value_ptr.dereference()
            yield "value", value

            it = it + 1

    def to_string(self):
        return "llvm::StringMap with %d elements" % (self.val["NumItems"])

    def display_hint(self):
        return "map"


class TwinePrinter:
    "Print a Twine"

    def __init__(self, val):
        self._val = val

    def display_hint(self):
        return "string"

    def string_from_pretty_printer_lookup(self, val):
        """Lookup the default pretty-printer for val and use it.

        If no pretty-printer is defined for the type of val, print an error and
        return a placeholder string."""

        pp = gdb.default_visualizer(val)
        if pp:
            s = pp.to_string()

            # The pretty-printer may return a LazyString instead of an actual Python
            # string.  Convert it to a Python string.  However, GDB doesn't seem to
            # register the LazyString type, so we can't check
            # "type(s) == gdb.LazyString".
            if "LazyString" in type(s).__name__:
                s = s.value().string()

        else:
            print(
                (
                    "No pretty printer for {} found. The resulting Twine "
                    + "representation will be incomplete."
                ).format(val.type.name)
            )
            s = "(missing {})".format(val.type.name)

        return s

    def is_twine_kind(self, kind, expected):
        if not kind.endswith(expected):
            return False
        # apparently some GDB versions add the NodeKind:: namespace
        # (happens for me on GDB 7.11)
        return kind in (
            "llvm::Twine::" + expected,
            "llvm::Twine::NodeKind::" + expected,
        )

    def string_from_child(self, child, kind):
        """Return the string representation of the Twine::Child child."""

        if self.is_twine_kind(kind, "EmptyKind") or self.is_twine_kind(
            kind, "NullKind"
        ):
            return ""

        if self.is_twine_kind(kind, "TwineKind"):
            return self.string_from_twine_object(child["twine"].dereference())

        if self.is_twine_kind(kind, "CStringKind"):
            return child["cString"].string()

        if self.is_twine_kind(kind, "StdStringKind"):
            val = child["stdString"].dereference()
            return self.string_from_pretty_printer_lookup(val)

        if self.is_twine_kind(kind, "PtrAndLengthKind"):
            val = child["ptrAndLength"]
            data = val["ptr"]
            length = val["length"]
            return data.string(length=length)

        if self.is_twine_kind(kind, "CharKind"):
            return chr(child["character"])

        if self.is_twine_kind(kind, "DecUIKind"):
            return str(child["decUI"])

        if self.is_twine_kind(kind, "DecIKind"):
            return str(child["decI"])

        if self.is_twine_kind(kind, "DecULKind"):
            return str(child["decUL"].dereference())

        if self.is_twine_kind(kind, "DecLKind"):
            return str(child["decL"].dereference())

        if self.is_twine_kind(kind, "DecULLKind"):
            return str(child["decULL"].dereference())

        if self.is_twine_kind(kind, "DecLLKind"):
            return str(child["decLL"].dereference())

        if self.is_twine_kind(kind, "UHexKind"):
            val = child["uHex"].dereference()
            return hex(int(val))

        print(
            (
                "Unhandled NodeKind {} in Twine pretty-printer. The result will be "
                "incomplete."
            ).format(kind)
        )

        return "(unhandled {})".format(kind)

    def string_from_twine_object(self, twine):
        """Return the string representation of the Twine object twine."""

        lhs = twine["LHS"]
        rhs = twine["RHS"]

        lhs_kind = str(twine["LHSKind"])
        rhs_kind = str(twine["RHSKind"])

        lhs_str = self.string_from_child(lhs, lhs_kind)
        rhs_str = self.string_from_child(rhs, rhs_kind)

        return lhs_str + rhs_str

    def to_string(self):
        return self.string_from_twine_object(self._val)

    def display_hint(self):
        return "string"


def get_pointer_int_pair(val):
    """Get tuple from llvm::PointerIntPair."""
    info_name = val.type.template_argument(4).strip_typedefs().name
    # Note: this throws a gdb.error if the info type is not used (by means of a
    # call to getPointer() or similar) in the current translation unit.
    enum_type = gdb.lookup_type(info_name + "::MaskAndShiftConstants")
    enum_dict = gdb.types.make_enum_dict(enum_type)
    ptr_mask = enum_dict[info_name + "::PointerBitMask"]
    int_shift = enum_dict[info_name + "::IntShift"]
    int_mask = enum_dict[info_name + "::IntMask"]
    pair_union = val["Value"]
    pointer = pair_union & ptr_mask
    value = (pair_union >> int_shift) & int_mask
    return (pointer, value)


class PointerIntPairPrinter:
    """Print a PointerIntPair."""

    def __init__(self, pointer, value):
        self.pointer = pointer
        self.value = value

    def children(self):
        yield ("pointer", self.pointer)
        yield ("value", self.value)

    def to_string(self):
        return "(%s, %s)" % (self.pointer.type, self.value.type)


def make_pointer_int_pair_printer(val):
    """Factory for an llvm::PointerIntPair printer."""
    try:
        pointer, value = get_pointer_int_pair(val)
    except gdb.error:
        return None  # If PointerIntPair cannot be analyzed, print as raw value.
    pointer_type = val.type.template_argument(0)
    value_type = val.type.template_argument(2)
    return PointerIntPairPrinter(pointer.cast(pointer_type), value.cast(value_type))


class PointerUnionPrinter:
    """Print a PointerUnion."""

    def __init__(self, pointer):
        self.pointer = pointer

    def children(self):
        yield ("pointer", self.pointer)

    def to_string(self):
        return "Containing %s" % self.pointer.type


def make_pointer_union_printer(val):
    """Factory for an llvm::PointerUnion printer."""
    try:
        pointer, value = get_pointer_int_pair(val["Val"])
    except gdb.error:
        return None  # If PointerIntPair cannot be analyzed, print as raw value.
    pointer_type = val.type.template_argument(int(value))
    return PointerUnionPrinter(pointer.cast(pointer_type))


class IlistNodePrinter:
    """Print an llvm::ilist_node object."""

    def __init__(self, val):
        impl_type = val.type.fields()[0].type
        base_type = impl_type.fields()[0].type
        derived_type = val.type.template_argument(0)

        def get_prev_and_sentinel(base):
            # One of Prev and PrevAndSentinel exists. Depending on #defines used to
            # compile LLVM, the base_type's template argument is either true of false.
            if base_type.template_argument(0):
                return get_pointer_int_pair(base["PrevAndSentinel"])
            return base["Prev"], None

        # Casts a base_type pointer to the appropriate derived type.
        def cast_pointer(pointer):
            sentinel = get_prev_and_sentinel(pointer.dereference())[1]
            pointer = pointer.cast(impl_type.pointer())
            if sentinel:
                return pointer
            return pointer.cast(derived_type.pointer())

        # Repeated cast becaue val.type's base_type is ambiguous when using tags.
        base = val.cast(impl_type).cast(base_type)
        (prev, sentinel) = get_prev_and_sentinel(base)
        prev = prev.cast(base_type.pointer())
        self.prev = cast_pointer(prev)
        self.next = cast_pointer(val["Next"])
        self.sentinel = sentinel

    def children(self):
        if self.sentinel:
            yield "sentinel", "yes"
        yield "prev", self.prev
        yield "next", self.next


class IlistPrinter:
    """Print an llvm::simple_ilist or llvm::iplist object."""

    def __init__(self, val):
        self.node_type = val.type.template_argument(0)
        sentinel = val["Sentinel"]
        # First field is common base type of sentinel and ilist_node.
        base_type = sentinel.type.fields()[0].type
        self.sentinel = sentinel.address.cast(base_type.pointer())

    def _pointers(self):
        pointer = self.sentinel
        while True:
            pointer = pointer["Next"].cast(pointer.type)
            if pointer == self.sentinel:
                return
            yield pointer.cast(self.node_type.pointer())

    def children(self):
        for k, v in enumerate(self._pointers()):
            yield ("[%d]" % k, v.dereference())


pp = gdb.printing.RegexpCollectionPrettyPrinter("LLVMSupport")
pp.add_printer("llvm::SmallString", "^llvm::SmallString<.*>$", SmallStringPrinter)
pp.add_printer("llvm::StringRef", "^llvm::StringRef$", StringRefPrinter)
pp.add_printer(
    "llvm::SmallVectorImpl", "^llvm::SmallVector(Impl)?<.*>$", SmallVectorPrinter
)
pp.add_printer("llvm::ArrayRef", "^llvm::(Mutable)?ArrayRef<.*>$", ArrayRefPrinter)
pp.add_printer("llvm::Expected", "^llvm::Expected<.*>$", ExpectedPrinter)
pp.add_printer("llvm::Optional", "^llvm::Optional<.*>$", OptionalPrinter)
pp.add_printer("llvm::DenseMap", "^llvm::DenseMap<.*>$", DenseMapPrinter)
pp.add_printer("llvm::StringMap", "^llvm::StringMap<.*>$", StringMapPrinter)
pp.add_printer("llvm::Twine", "^llvm::Twine$", TwinePrinter)
pp.add_printer(
    "llvm::PointerIntPair", "^llvm::PointerIntPair<.*>$", make_pointer_int_pair_printer
)
pp.add_printer(
    "llvm::PointerUnion", "^llvm::PointerUnion<.*>$", make_pointer_union_printer
)
pp.add_printer("llvm::ilist_node", "^llvm::ilist_node<.*>$", IlistNodePrinter)
pp.add_printer("llvm::iplist", "^llvm::iplist<.*>$", IlistPrinter)
pp.add_printer("llvm::simple_ilist", "^llvm::simple_ilist<.*>$", IlistPrinter)
gdb.printing.register_pretty_printer(gdb.current_objfile(), pp)
