#!/usr/bin/env python2.4

import os, os.path, re, sys
from warnings import warn

# Default patch directory (option -H): series files are looked up under
# <home>/series and the patch paths named in them are joined onto <home>.
_default_home = "/usr/src/kernel-patches/all/2.6.18/debian"
# Default revision list (option -R), whitespace separated.  do() prepends
# '0' and uses the positions in this list to decide which series to run.
_default_revisions = "1 2 3 4 5 6 7 8 9 10 11 12 12etch1 12etch2 13 13etch1 13etch2 13etch3 13etch4 13etch5 13etch6 14 15 16 17 17etch1 18 18etch1 18etch2 18etch3 18etch4 18etch5 18etch6 19 20 21 22 22etch1 22etch2 22etch3 23 23etch1 24 24etch1 24etch2 24etch3 24etch4 25 26 26etch1 26etch2"
# Default source version (option -S) in "<upstream>-<revision>" form; it is
# also the target when no TARGET argument is given.
_default_source = "2.6.18-26etch2"

class series(list):
    """Patch operations for one revision, read from <home>/series/<name>.

    Blank lines and lines starting with '#' are skipped.  Every other line
    must read "<operation> <patch>" with exactly one space in between, where
    the operation is '+' or '-'.  Each line becomes one list entry
    (operation, patch, (patchfile, kind)); patchfile is the existing file
    found for the patch and kind is 0 (plain), 1 (.bz2) or 2 (.gz).

    The constructor raises RuntimeError when the series file or a patch is
    missing, or when a line cannot be parsed.
    """

    # Suffixes probed for every patch, in order, with their kind code.
    _suffixes = (('', 0), ('.bz2', 1), ('.gz', 2))
    # Command that writes a patch of the given kind to stdout.
    _readers = {0: 'cat', 1: 'bzcat', 2: 'zcat'}

    def __init__(self, name, home, reverse = False):
        self.name = name
        self.reverse = reverse

        filename = os.path.join(home, 'series', name)
        if not os.path.exists(filename):
            raise RuntimeError("Can't find series file for %s" % name)

        # Close the handle explicitly instead of leaving it to the garbage
        # collector (try/finally because the script targets Python 2.4).
        f = open(filename)
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()

            if len(line) == 0 or line[0] == '#':
                continue

            items = line.split(' ')
            if len(items) != 2:
                raise RuntimeError("Line '%s' in file %s malformed." % (line, filename))
            operation, patch = items

            if operation not in ('+', '-'):
                raise RuntimeError('Undefined operation "%s" in series %s' % (operation, name))

            self.append((operation, patch, self._find_patch(home, patch)))

    def __repr__(self):
        return '<%s object for %s>' % (self.__class__.__name__, self.name)

    def _find_patch(self, home, patch):
        """Return (patchfile, kind) for the first existing variant of patch."""
        patchfile = os.path.join(home, patch)
        for suffix, kind in self._suffixes:
            if os.path.exists(patchfile + suffix):
                return patchfile + suffix, kind
        raise RuntimeError("Can't find patch %s for series %s" % (patchfile, self.name))

    @staticmethod
    def _shell_quote(text):
        """Return text wrapped in single quotes for use in an sh command."""
        return "'" + text.replace("'", "'\\''") + "'"

    def apply(self):
        """Run every entry of the series.

        Normally entries run first to last, '+' applying and '-' reverting
        the patch.  With self.reverse set they run last to first with the
        two operations swapped.  Entries marked '.' are only reported.
        """
        if self.reverse:
            entries = self[::-1]
            handlers = {'+': self.patch_deapply, '-': self.patch_apply}
            state = "unapplied"
        else:
            entries = self
            handlers = {'+': self.patch_apply, '-': self.patch_deapply}
            state = "applied"

        for operation, patch, patchinfo in entries:
            if operation == '.':
                print("  (.) IGNORED   %s" % patch)
            elif operation in handlers:
                handlers[operation](patch, patchinfo)
        print("--> %s fully %s." % (self.name, state))

    def patch_apply(self, patch, patchinfo):
        """Apply one patch with --fuzz=1; raise SystemExit(1) on failure."""
        self._run_patch(patch, patchinfo, '--fuzz=1', '+')

    def patch_call(self, patchinfo, patchargs):
        """Pipe the (possibly compressed) patch into patch(1).

        patchinfo is the (patchfile, kind) pair stored in the list entry and
        patchargs is appended verbatim to the patch command line.  Returns
        the exit status reported by os.spawnv for the shell pipeline.
        """
        patchfile, kind = patchinfo
        # The file name is quoted so that a home directory containing spaces
        # or shell metacharacters does not break the pipeline.
        cmdline = ' '.join([
            self._readers[kind],
            self._shell_quote(patchfile) + ' | patch -p1 -f -s -t --no-backup-if-mismatch',
            patchargs,
        ])
        return os.spawnv(os.P_WAIT, '/bin/sh', ['sh', '-c', cmdline])

    def patch_deapply(self, patch, patchinfo):
        """Revert one patch with -R; raise SystemExit(1) on failure."""
        self._run_patch(patch, patchinfo, '-R', '-')

    def _run_patch(self, patch, patchinfo, patchargs, mark):
        """Run patch_call, report the result and abort on failure."""
        if self.patch_call(patchinfo, patchargs) == 0:
            print("  (%s) OK        %s" % (mark, patch))
        else:
            print("  (%s) FAIL      %s" % (mark, patch))
            raise SystemExit(1)

    @staticmethod
    def read_all(s, home, reverse = False):
        """Return a series object for every name in s that has a series file.

        Names without a file under <home>/series only produce a warning.
        """
        ret = []

        for name in s:
            filename = os.path.join(home, 'series', name)
            if not os.path.exists(filename):
                warn("Can't find series file for %s" % name)
            else:
                ret.append(series(name, home, reverse))

        return ret

class series_extra(series):
    """Series of optional patches, read from <home>/series/<name>-extra.

    Lines read "<operation> <patch> [<selector> ...]".  A selector is one or
    two fields joined by '_'; three fields are rejected as unsupported.
    self.extra is the tuple the selectors are matched against (do() passes
    the arch/subarch/flavour values).  A patch whose selectors do not match
    is stored with the operation '.', which apply() only reports.

    self.extra_used records a matching selector tuple; do() writes the
    longest one it finds to the version file.
    """

    def __init__(self, name, home, extra, reverse = False):
        self.extra = extra
        self.extra_used = ()
        self.name = name
        self.reverse = reverse

        filename = os.path.join(home, 'series', name + '-extra')
        if not os.path.exists(filename):
            raise RuntimeError("Can't find series file for %s" % name)

        # Close the handle explicitly instead of leaving it to the garbage
        # collector (try/finally because the script targets Python 2.4).
        f = open(filename)
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()

            if len(line) == 0 or line[0] == '#':
                continue

            items = line.split(' ')
            if len(items) < 2:
                # Report a short line like the plain series does, not as a
                # bare unpacking ValueError.
                raise RuntimeError("Line '%s' in file %s malformed." % (line, filename))
            operation, patch = items[:2]

            if operation not in ('+', '-'):
                raise RuntimeError('Undefined operation "%s" in series %s' % (operation, name))

            patchfile = os.path.join(home, patch)
            for suffix, kind in (('', 0), ('.bz2', 1), ('.gz', 2)):
                if os.path.exists(patchfile + suffix):
                    patchinfo = patchfile + suffix, kind
                    break
            else:
                raise RuntimeError("Can't find patch %s for series %s" % (patchfile, name))

            # Selector tuples named on this line, used as a set.
            selectors = {}
            for item in items[2:]:
                key = tuple(item.split('_'))
                if len(key) > 3:
                    raise RuntimeError("parse error")
                if len(key) == 3:
                    raise RuntimeError("Patch per flavour is not supported currently")
                selectors[key] = True
            if not self._check_extra(selectors):
                operation = '.'

            self.append((operation, patch, patchinfo))

    def _check_extra(self, extra):
        """Return True if one of the selector tuples in extra matches.

        The prefixes of self.extra of length 1 to 3 are tried first, then
        the same prefixes of length 2 and 3 with '*' in the first position.
        The first hit is stored in self.extra_used when its prefix length
        exceeds the length of the value already stored.
        """
        for i in (1, 2, 3):
            t = self.extra[:i]
            if t in extra:
                if i > len(self.extra_used):
                    self.extra_used = t
                return True
        for i in (2, 3):
            t = ('*',) + self.extra[1:i]
            if t in extra:
                if i > len(self.extra_used):
                    self.extra_used = t
                return True
        return False

    @staticmethod
    def read_all(s, home, extra, reverse = False):
        """Return a series_extra for every name in s that has an -extra file.

        Names without a file under <home>/series only produce a warning.
        """
        ret = []

        for name in s:
            filename = os.path.join(home, 'series', name + '-extra')
            if not os.path.exists(filename):
                warn("Can't find series file for %s" % name)
            else:
                ret.append(series_extra(name, home, extra, reverse))

        return ret

class series_orig(series):
    """Series read from <home>/series/orig-<name>.

    Lines read "<operation> <path>".  '+' names a patch to apply; it is
    looked up plain, as .bz2 or as .gz and stored as (patchfile, kind).
    'X' names a file listing paths to delete; it is stored as the plain
    path string.  Any other operation raises RuntimeError.
    """

    def __init__(self, name, home):
        self.name = name

        filename = os.path.join(home, 'series', 'orig-' + name)
        if not os.path.exists(filename):
            raise RuntimeError("Can't find series file for %s" % name)

        # Close the handle explicitly instead of leaving it to the garbage
        # collector (try/finally because the script targets Python 2.4).
        f = open(filename)
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()

            if len(line) == 0 or line[0] == '#':
                continue

            items = line.split(' ')
            if len(items) < 2:
                # Report a short line like the plain series does, not as a
                # bare unpacking ValueError.
                raise RuntimeError("Line '%s' in file %s malformed." % (line, filename))
            operation, patch = items[:2]

            patchfile = os.path.join(home, patch)
            if operation == '+':
                for suffix, kind in (('', 0), ('.bz2', 1), ('.gz', 2)):
                    if os.path.exists(patchfile + suffix):
                        patchinfo = patchfile + suffix, kind
                        break
                else:
                    raise RuntimeError("Can't find patch %s for series %s" % (patchfile, name))
            elif operation == 'X':
                if not os.path.exists(patchfile):
                    raise RuntimeError("Can't find patch %s for series %s" % (patchfile, name))
                patchinfo = patchfile
            else:
                raise RuntimeError('Undefined operation "%s" in series %s' % (operation, name))

            self.append((operation, patch, patchinfo))

    def apply(self):
        """Apply every '+' patch and process every 'X' removal list in order."""
        for operation, patch, patchinfo in self:
            if operation == '+':
                self.patch_apply(patch, patchinfo)
            elif operation == 'X':
                self.files_remove(patch, patchinfo)
        print("--> %s fully applied." % self.name)

    def files_remove(self, patch, patchinfo):
        """Unlink every path listed in the file patchinfo.

        Blank lines and lines starting with '#' are skipped.  Paths are
        passed to os.unlink as written in the list file.
        """
        f = open(patchinfo)
        try:
            lines = f.readlines()
        finally:
            f.close()

        for line in lines:
            line = line.strip()
            if len(line) == 0 or line[0] == '#':
                continue

            os.unlink(line)

        print("  (X) OK        %s" % patch)

class version(object):
    """Version string split into an upstream part and a revision.

    The string is split at its last '-': "2.6.18-26etch2" gives upstream
    "2.6.18" and revision "26etch2".  Without an argument the constructor
    leaves both attributes unset so the caller can assign them.
    """

    __slots__ = "upstream", "revision"

    def __init__(self, string = None):
        if string is not None:
            self.upstream, self.revision = self.parse(string)

    def __str__(self):
        return "%s-%s" % (self.upstream, self.revision)

    # Verbose-mode pattern: group 1 is the upstream version (optionally
    # "x.y.z+" in front and "-suffix" behind the x.y.z core), group 2 the
    # revision, which contains no '-'.
    _re = r"""
^
(
    (?:
        \d+\.\d+\.\d+\+
    )?
    \d+\.\d+\.\d+
    (?:
        -.+?
    )?
)
-
([^-]+)
$
"""

    def parse(self, version):
        """Return (upstream, revision) for the given version string.

        Raises ValueError, naming the offending string, when it does not
        have the expected form.
        """
        match = re.match(self._re, version, re.X)
        if match is None:
            raise ValueError('Can\'t parse version "%s"' % (version,))
        return match.groups()

class version_file(object):
    """State recorded in the file 'version.Debian' of the current directory.

    The file holds one line: "<version>", "<version> <extra>" with the
    extra fields joined by '_', or "unstable" while an operation is in
    progress.
    """

    _file = 'version.Debian'
    extra = ()
    in_progress = False

    def __init__(self, ver = None, overwrite = False):
        """Determine the current version.

        With overwrite set, ver is a string that is parsed in place of the
        file contents.  Otherwise the file is read if it exists; if it does
        not and ver (a version object) is given, revision '0' of
        ver.upstream is assumed and a warning is issued.  Raises
        RuntimeError when no version can be determined or parsed.
        """
        if overwrite:
            self._read(ver)
        elif os.path.exists(self._file):
            # Close the handle explicitly (try/finally for Python 2.4).
            f = open(self._file)
            try:
                s = f.readline().strip()
            finally:
                f.close()
            self._read(s)
        elif ver:
            warn('No %s file, assuming pristine Linux %s' % (self._file, ver.upstream))
            self.version = version()
            self.version.upstream = ver.upstream
            self.version.revision = '0'
        else:
            raise RuntimeError("Not possible to determine version")

    def __str__(self):
        if self.in_progress:
            return "unstable"
        if self.extra:
            return "%s %s" % (self.version, '_'.join(self.extra))
        return str(self.version)

    def _read(self, s):
        """Parse "<version>[ <extra>]" into self.version and self.extra."""
        fields = s.split(' ')
        if len(fields) > 2:
            raise RuntimeError("Can't parse %s" % self._file)
        try:
            self.version = version(fields[0])
        except ValueError:
            raise RuntimeError('Can\'t read version in %s: "%s"' % (self._file, fields[0]))
        if len(fields) == 2:
            self.extra = tuple(fields[1].split('_'))

    def _write(self):
        """Replace the version file with the current state."""
        content = '%s\n' % self
        # Unlink first so that a symlink is replaced, not written through.
        if os.path.lexists(self._file):
            os.unlink(self._file)
        # Close explicitly so the data is flushed before patches run.
        f = open(self._file, 'w')
        try:
            f.write(content)
        finally:
            f.close()

    def begin(self):
        """Mark the state as in progress ("unstable") on disk."""
        self.in_progress = True
        self._write()

    def commit(self, version = None, extra = None):
        """Clear the in-progress mark and write the given version and extra.

        Arguments left as None keep the value currently stored.
        """
        self.in_progress = False
        if version is not None:
            self.version = version
        if extra is not None:
            self.extra = extra
        self._write()

def main():
    """Parse the command line and run the requested operation.

    With --orig the named orig series is applied; otherwise the tree is
    moved to the optional TARGET version.  Raises RuntimeError when more
    than one positional argument is given, so that the error is reported
    on stderr with a non-zero exit status like every other failure.
    """
    options, args = parse_options()

    if len(args) > 1:
        raise RuntimeError("Too many arguments")

    if options.orig:
        do_orig(options)
    else:
        do(options, args)

def _parse_version(text, what):
    """Return version(text); report a malformed string as RuntimeError."""
    try:
        return version(text)
    except ValueError:
        raise RuntimeError('Can\'t parse %s version "%s"' % (what, text))

def _apply_all(vfile, batch):
    """Mark the version file as in progress, then apply every series."""
    vfile.begin()
    for entry in batch:
        entry.apply()

def do(options, args):
    """Bring the tree from its current version to the target version.

    The target is args[0] if given, otherwise options.source.  The current
    version comes from options.current if set, otherwise from the version
    file.  The steps, each bracketed by vfile.begin() and vfile.commit():

    1. If the current state carries extra patches, revert them.  This is
       only allowed when the current revision is the source revision.
    2. Apply or revert the plain series between the current and the
       target revision.
    3. If arch/subarch/flavour were requested, apply the extra series of
       all revisions up to the target.

    Raises RuntimeError for unparsable versions, for revisions missing
    from options.revisions and for extras on a non-source revision.
    """
    home = options.home
    revisions = ['0'] + options.revisions.split()
    # A malformed version becomes a RuntimeError so that it is reported as
    # "Error: ..." by the caller instead of escaping as a traceback.
    source = _parse_version(options.source, 'source')
    if len(args) == 1:
        target = _parse_version(args[0], 'target')
    else:
        target = source

    if options.current is not None:
        vfile = version_file(options.current, True)
    else:
        vfile = version_file(source)
    current = vfile.version
    current_extra = vfile.extra

    requested = (options.arch, options.subarch, options.flavour)
    target_extra = tuple([value for value in requested if value])

    if current.revision not in revisions:
        raise RuntimeError("Current revision is not in our list of revisions")
    if target.revision not in revisions:
        raise RuntimeError("Target revision is not in our list of revisions")

    if current.revision == target.revision and current_extra == target_extra:
        print("Nothing to do")
        return

    current_index = revisions.index(current.revision)
    source_index = revisions.index(source.revision)
    target_index = revisions.index(target.revision)

    if current_extra:
        if current_index != source_index:
            raise RuntimeError("Can't patch from %s with options %s" % (current, ' '.join(current_extra)))
        # Newest first, stopping before the pseudo revision '0'.
        consider = revisions[current_index:0:-1]
        s = series_extra.read_all(consider, home, current_extra, reverse = True)
        _apply_all(vfile, s)
        vfile.commit(current, ())

    if current_index < target_index:
        consider = revisions[current_index + 1:target_index + 1]
        s = series.read_all(consider, home)
        _apply_all(vfile, s)
        vfile.commit(target, ())
    elif current_index > target_index:
        consider = revisions[current_index:target_index:-1]
        s = series.read_all(consider, home, reverse = True)
        _apply_all(vfile, s)
        vfile.commit(target, ())

    if target_extra:
        consider = revisions[1:target_index + 1]
        s = series_extra.read_all(consider, home, target_extra)
        # Record the longest selector that matched in any of the series.
        real_extra = ()
        for i in s:
            t = i.extra_used
            if len(t) > len(real_extra):
                real_extra = t
        _apply_all(vfile, s)
        vfile.commit(target, real_extra)

def do_orig(options):
    """Load the orig series named by options.orig from options.home and apply it."""
    orig_series = series_orig(options.orig, options.home)
    orig_series.apply()

def parse_options():
    """Parse sys.argv and return the (options, args) pair from optparse.

    Raises RuntimeError when a subarch is given without an arch or a
    flavour without a subarch.
    """
    from optparse import OptionParser

    # One row per option: the flag names, then the add_option() keywords.
    option_table = (
        (('-a', '--arch'),
         {'dest': 'arch', 'help': "arch"}),
        (('-f', '--flavour'),
         {'dest': 'flavour', 'help': "flavour"}),
        (('-s', '--subarch'),
         {'dest': 'subarch', 'help': "subarch"}),
        (('-C', '--overwrite-current'),
         {'dest': 'current', 'help': "overwrite current"}),
        (('-H', '--overwrite-home'),
         {'default': _default_home, 'dest': 'home',
          'help': "overwrite home [default: %default]"}),
        (('-R', '--overwrite-revisions'),
         {'default': _default_revisions, 'dest': 'revisions',
          'help': "overwrite revisions [default: %default]"}),
        (('-S', '--overwrite-source'),
         {'default': _default_source, 'dest': 'source',
          'help': "overwrite source [default: %default]"}),
        (('--orig',),
         {'dest': 'orig'}),
    )

    parser = OptionParser(usage = "%prog [OPTION]... [TARGET]")
    for flags, settings in option_table:
        parser.add_option(*flags, **settings)

    options, args = parser.parse_args()

    # The extra selectors nest: a subarch needs an arch, a flavour needs a
    # subarch.
    have_arch = options.arch is not None
    have_subarch = options.subarch is not None
    have_flavour = options.flavour is not None
    if have_subarch and not have_arch:
        raise RuntimeError('You specified a subarch without an arch, this is not really working')
    if have_flavour and not have_subarch:
        raise RuntimeError('You specified a flavour without a subarch, this is not really working')

    return options, args

if __name__ == '__main__':
    # Replace the default warning output with a single "Warning: ..." line.
    # The optional file and line parameters match the documented signature
    # of warnings.showwarning, so callers that pass them keep working.
    def showwarning(message, category, filename, lineno, file = None, line = None):
        if file is None:
            file = sys.stderr
        file.write("Warning: %s\n" % message)
    import warnings
    warnings.showwarning = showwarning
    try:
        main()
    except RuntimeError:
        # sys.exc_info() works on the Python 2.4 target as well as on
        # interpreters that no longer accept "except RuntimeError, e".
        e = sys.exc_info()[1]
        sys.stderr.write("Error: %s\n" % e)
        raise SystemExit(1)

