#! /usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright (C) 2012-2014 by László Nagy
# This file is part of Bear.
#
# Bear is a tool to generate compilation database for clang tooling.
#
# Bear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Bear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
""" This module is responsible to capture the compiler invocation of any
build process. The result of that should be a compilation database.

This implementation is using the LD_PRELOAD or DYLD_INSERT_LIBRARIES
mechanisms provided by the dynamic linker. The related library is implemented
in C language and can be found under 'libear' directory.

The 'libear' library is capturing all child process creation and logging the
relevant information about it into separate files in a specified directory.
The input of the library is therefore the output directory which is passed
as an environment variable.

This module implements the build command execution with the 'libear' library
and the post-processing of the output files, which are condensed into a
(possibly empty) compilation database. """

import argparse
import collections
import subprocess
import json
import sys
import os
import os.path
import re
import shlex
import itertools
import tempfile
import shutil
import contextlib
import logging

# Ignored compiler options map for compilation database creation.
# The map is consulted by the `split_command` function, which drops these
# options while it classifies the rest of the command line. Please note
# that these are not the only parameters which might be ignored.
#
# Keys are the option names, values are the number of arguments following
# the option which have to be skipped together with it.
IGNORED_FLAGS = {
    # compile-only flag, ignored because the creator of the compilation
    # database entry will explicitly set it.
    '-c': 0,
    # dependency generation options of the preprocessor, ignored because
    # they would cause duplicate entries in the output (the only difference
    # would be these flags). this is an actual finding from users, who
    # suffered longer execution time caused by the duplicates.
    '-MD': 0,
    '-MMD': 0,
    '-MG': 0,
    '-MP': 0,
    '-MF': 1,
    '-MT': 1,
    '-MQ': 1,
    # linker options, ignored because the compilation database contains
    # compilation commands only, so the compiler would ignore these flags
    # anyway. the benefit of getting rid of them is a more readable output.
    '-static': 0,
    '-shared': 0,
    '-s': 0,
    '-rdynamic': 0,
    '-l': 1,
    '-L': 1,
    '-u': 1,
    '-z': 1,
    '-T': 1,
    '-Xlinker': 1
}

# Known C/C++ compiler wrapper name patterns. In `split_compiler` a wrapper
# may be followed by the real compiler (or by another wrapper).
COMPILER_WRAPPER_PATTERN = re.compile(r'^(distcc|ccache)$')

# Known C/C++ compiler executable name patterns. `split_compiler` matches
# them against the base name of the executed program.
COMPILER_PATTERNS = frozenset([
    # the generic names: cc, c++, cxx, CC
    re.compile(r'^(cc|c\+\+|cxx|CC)$'),
    # gcc, g++, mcc, m++ with optional dash terminated prefixes and an
    # optional version suffix, like 'arm-linux-gnueabi-gcc' or 'g++-4.8'
    re.compile(r'^([^-]*-)*[mg](cc|\+\+)(-\d+(\.\d+){0,2})?$'),
    # clang, clang++ with the same optional prefixes and version suffix
    re.compile(r'^([^-]*-)*clang(\+\+)?(-\d+(\.\d+){0,2})?$'),
    # llvm-gcc, llvm-g++
    re.compile(r'^llvm-g(cc|\+\+)$'),
    # icc, icpc
    re.compile(r'^i(cc|cpc)$'),
    # xlc, xlC, xlc++ and the same names with a leading 'g'
    re.compile(r'^(g|)xl(c|C|c\+\+)$'),
])

# Known C++ compiler executable name patterns. `split_compiler` applies
# them only to names which already matched `COMPILER_PATTERNS`.
COMPILER_CPP_PATTERNS = frozenset([
    # C++ compilers usually end with '++', optionally followed by a dash
    # separated suffix
    re.compile(r'^(.+)(\+\+)(-.+|)$'),  # C++ compilers usually ends with '++'
    # C++ compilers whose name does not contain '++'
    re.compile(r'^(icpc|xlC|cxx|CC)$'),
])


def main():
    """ Command line entry point of the program.

    Parses the arguments, sets up the logging and hands the work over to
    `capture`. When no build command was given, only the help is printed.

    :return: the exit code of the process: the result of `capture`, 0 when
             only the help was printed, 1 on keyboard interrupt and 127
             when any other exception was raised. """
    try:
        parser = create_parser()
        args = parser.parse_args()

        # every '-v' lowers the log threshold by one level (10), but the
        # threshold never goes below zero.
        verbosity = min(logging.WARNING, 10 * args.verbose)
        logging.basicConfig(format='bear: %(message)s',
                            level=logging.WARNING - verbosity)
        logging.debug('version: 2.2.1')
        logging.debug(args)

        if args.build:
            return capture(args)

        parser.print_help()
        return 0
    except KeyboardInterrupt:
        return 1
    except Exception:
        logging.exception('Internal error.')
        debugging = logging.getLogger().isEnabledFor(logging.DEBUG)
        if not debugging:
            logging.error("Please run this command again and turn on "
                          "verbose mode (add '-vvvv' as argument).")
        else:
            logging.error("Please report this bug and attach the output "
                          "to the bug report")
        return 127
    finally:
        logging.shutdown()


def capture(args):
    """ The entry point of build command interception.

    Runs the build command with the interception library set up in its
    environment, reads the exec traces written into a temporary directory
    and writes the resulting entries as JSON into `args.cdb`.

    :param args: the parsed command line arguments. The attributes `build`,
                 `libear`, `raw_entries`, `append` and `cdb` are read.
    :return:     the exit code of the build command (1 when the build was
                 interrupted by the user). """

    @contextlib.contextmanager
    def temporary_directory(**kwargs):
        """ Create a temporary directory which is deleted on exit.

        The keyword arguments are passed to `tempfile.mkdtemp`. """
        name = tempfile.mkdtemp(**kwargs)
        try:
            yield name
        finally:
            shutil.rmtree(name)

    def post_processing(commands):
        """ To make a compilation database, it needs to filter out commands
        which are not compiler calls. Needs to find the source file name
        from the arguments. And do shell escaping on the command.

        To support incremental builds, it is desired to read elements from
        an existing compilation database from a previous run. These elements
        shall be merged with the new elements. """

        # create entries from the current run
        current = itertools.chain.from_iterable(
            # creates a sequence of entry generators from an exec,
            format_entry(command) for command in commands)
        # read entries from previous run
        if args.append and os.path.isfile(args.cdb):
            with open(args.cdb) as handle:
                previous = iter(json.load(handle))
        else:
            previous = iter([])
        # filter out duplicate entries from both, and entries whose source
        # file does not exist (any more)
        duplicate = duplicate_check(entry_hash)
        return (entry
                for entry in itertools.chain(previous, current)
                if os.path.exists(entry['file']) and not duplicate(entry))

    def run_build(command, environment):
        """ Run the build command and return its exit code.

        A keyboard interrupt is not propagated, so the traces collected
        till that point are still written out. """
        try:
            logging.info('run build in environment: %s', environment)
            exit_code = subprocess.call(command, env=environment)
            logging.info('build finished with exit code: %s', exit_code)
            return exit_code
        except KeyboardInterrupt:
            logging.warning('build interupted by user. (writing output)')
            return 1

    with temporary_directory(prefix='bear-') as tmp_dir:
        # run the build command
        environment = setup_environment(tmp_dir, args.libear)
        exit_code = run_build(args.build, environment)
        # read the intercepted exec calls
        exec_traces = itertools.chain.from_iterable(
            parse_exec_trace(os.path.join(tmp_dir, filename))
            for filename in sorted(os.listdir(tmp_dir)))
        # do post processing only if that was requested
        if not args.raw_entries:
            entries = post_processing(exec_traces)
        else:
            entries = exec_traces
        # The entries are lazy generators. Evaluate them before the output
        # file is opened: opening truncates the file, so a failure during
        # the evaluation would otherwise destroy an existing compilation
        # database (which matters most for the append mode).
        entries = list(entries)
        # dump the compilation database
        with open(args.cdb, 'w+') as handle:
            json.dump(entries, handle, sort_keys=True, indent=4)
        return exit_code


def setup_environment(destination, ear_library_path):
    """ Create the environment for the build command.

    The result is a copy of the current process environment, extended with
    the variables which make the dynamic linker preload the 'libear'
    library and which tell the library where to write its reports.

    :param destination:      value of the 'BEAR_OUTPUT' variable,
    :param ear_library_path: the library to preload,
    :return:                 the environment as a new dictionary. """

    environment = dict(os.environ)
    environment['BEAR_OUTPUT'] = destination

    if sys.platform != 'darwin':
        environment['LD_PRELOAD'] = ear_library_path
        return environment

    # OS X has its own variables for library preloading
    environment['DYLD_INSERT_LIBRARIES'] = ear_library_path
    environment['DYLD_FORCE_FLAT_NAMESPACE'] = '1'
    return environment


def parse_exec_trace(filename):
    """ Read the reports from a file written by the 'libear' library.

    The file content is a sequence of groups, each closed by the 'group
    separator' character. The fields of a group are divided by the 'record
    separator' character. The last field read here is the command, whose
    arguments are each closed by the 'unit separator' character.

    :param filename: path of the report file,
    :return:         generator of dictionaries, one per group, with the
                     keys 'pid', 'ppid', 'function', 'directory' and
                     'command' (the last one is a list of strings). """

    group_separator = '\x1d'
    record_separator = '\x1e'
    unit_separator = '\x1f'

    logging.debug('parse exec trace file: %s', filename)
    with open(filename, 'r') as handler:
        for group in handler.read().split(group_separator):
            # the text after the last separator (if any) is empty
            if not group:
                continue
            fields = group.split(record_separator)
            yield dict(
                pid=fields[0],
                ppid=fields[1],
                function=fields[2],
                directory=fields[3],
                # every argument is closed by the separator, therefore
                # the last element of the split is dropped
                command=fields[4].split(unit_separator)[:-1])


def format_entry(exec_trace):
    """ Create compilation database entries from an exec trace.

    :param exec_trace: dictionary with the 'command' (list of strings) and
                       the 'directory' of an executed program,
    :return:           generator of entries (dictionaries with 'directory',
                       'command' and 'file' keys), one per source file.
                       Generates nothing if `split_command` does not
                       recognise the command as a compilation. """

    logging.debug('format this command: %s', exec_trace['command'])
    compilation = split_command(exec_trace['command'])
    if not compilation:
        return

    # the real compiler name is replaced by a generic one
    compiler = 'c++' if compilation.compiler == 'c++' else 'cc'
    for source in compilation.files:
        command = [compiler, '-c']
        command.extend(compilation.flags)
        command.append(source)
        logging.debug('formated as: %s', command)

        directory = exec_trace['directory']
        escaped = shell_escape(command)
        # relative file names are relative to the working directory
        if os.path.isabs(source):
            fullname = source
        else:
            fullname = os.path.join(exec_trace['directory'], source)
        yield {
            'directory': directory,
            'command': escaped,
            'file': os.path.normpath(fullname)
        }


def shell_escape(command):
    """ Takes a command as list and returns a string.

    Every argument is protected in a way that POSIX style word splitting
    (for example `shlex.split`) of the result gives back the original
    list of arguments.

    :param command: the command as list of argument strings,
    :return:        the protected arguments joined by single spaces. """

    reserved = frozenset(' $%&()[]{}*|<>@?!')
    escapable = reserved | {'\\'}
    # Inside double quotes only the backslash and the double quote need a
    # backslash. (A backslash before any other character, like the space,
    # would be kept as part of the argument by a POSIX parser.)
    quoted_table = {'\\': '\\\\', '"': '\\"'}
    # Outside of quotes the quote characters and the space need one too.
    bare_table = {'\\': '\\\\', '"': '\\"', "'": "\\'", ' ': '\\ '}

    def needs_quote(word):
        """ Returns true if arguments needs to be protected by quotes.

        Previous implementation was shlex.split method, but that's not good
        for this job. Currently is running through the string with a basic
        state checking. The states are: 0 is plain text, 1 is after a
        backslash, 2 is between double quotes, 3 is between single quotes.
        """

        # An empty argument would vanish from the joined string. And a
        # whitespace other than the space can't be protected reliably by a
        # backslash (backslash-newline is a line continuation in shells).
        # Both are representable between quotes only.
        if not word or any(c.isspace() and c != ' ' for c in word):
            return True

        state = 0
        for current in word:
            if state == 0 and current in reserved:
                return True
            elif state == 0 and current == '\\':
                state = 1
            elif state == 1 and current in escapable:
                state = 0
            elif state == 0 and current == '"':
                state = 2
            elif state == 2 and current == '"':
                state = 0
            elif state == 0 and current == "'":
                state = 3
            elif state == 3 and current == "'":
                state = 0
        # unbalanced quotes or a dangling backslash need protection too
        return state != 0

    def escape(word):
        """ Do protect argument if that's needed. """

        if needs_quote(word):
            return '"' + ''.join(quoted_table.get(c, c) for c in word) + '"'
        return ''.join(bare_table.get(c, c) for c in word)

    return " ".join(escape(arg) for arg in command)


def entry_hash(entry):
    """ Calculate the identity of a compilation database entry.

    :param entry: dictionary with 'file', 'directory' and 'command' keys,
    :return:      a string, which is the same for entries which differ
                  only in the first word of the command. """

    parts = (
        # file and directory names are reverted for faster lookup in set
        entry['file'][::-1],
        entry['directory'][::-1],
        # On OS X the 'cc' and 'c++' compilers are wrappers for 'clang',
        # therefore both calls would be logged. To avoid this, the first
        # word of the command is not part of the hash.
        ' '.join(shlex.split(entry['command'])[1:]),
    )
    return '<>'.join(parts)


def duplicate_check(hash_function):
    """ Factory of predicates which detect already seen entries.

    The entries are dictionaries, which can't be stored in a `set`,
    because the `dict` type is not hashable. The predicate therefore stores
    the value calculated by the given hash function instead of the entry.
    So it can only tell whether the hash value was already seen or not.

    :param hash_function: calculates a hashable value from an entry,
    :return:              the predicate, with its own empty memory. """

    seen = set()

    def predicate(entry):
        """ Tell whether the hash of the entry was seen by this predicate
        before, and remember the hash.

        :param entry: the questioned entry,
        :return: true if the hash value was already seen, false otherwise.
        """
        key = hash_function(entry)
        known = key in seen
        if not known:
            seen.add(key)
        return known

    return predicate


def create_parser():
    """ Create the command line argument parser of the program.

    :return: an `argparse.ArgumentParser` which understands the generic
             options, the advanced options and the build command (which is
             the remainder of the command line). """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # generic options
    parser.add_argument(
        '--version',
        version='%(prog)s 2.2.1',
        action='version')
    parser.add_argument(
        '--verbose', '-v',
        default=0,
        action='count',
        help="""enable verbose output from '%(prog)s'. A second '-v' increases
                verbosity.""")
    parser.add_argument(
        '--cdb', '-o',
        default='compile_commands.json',
        metavar='<file>',
        help='The JSON compilation database.')
    parser.add_argument(
        '--append', '-a',
        action='store_true',
        help='appends new entries to existing compilation database.')

    # options for testing and troubleshooting
    advanced = parser.add_argument_group('advanced options')
    advanced.add_argument(
        '--disable-filter', '-n',
        action='store_true',
        dest='raw_entries',
        help='disable filter, unformatted output.')
    advanced.add_argument(
        '--libear', '-l',
        action='store',
        dest='libear',
        default='/usr/lib/powerpc64le-linux-gnu/bear/libear.so',
        help='specify libear file location.')

    # the build command is everything after the options
    parser.add_argument(
        dest='build',
        nargs=argparse.REMAINDER,
        help='command to run.')

    return parser


# The value type returned by `split_command`.
_Compilation = collections.namedtuple('Compilation',
                                      ['compiler', 'flags', 'files'])


def split_command(command):
    """ Returns a value when the command is a compilation, None otherwise.

    The value on success is a named tuple with the following attributes:

        files:    list of source files
        flags:    list of compile options
        compiler: string value of 'c' or 'c++'

    :param command: the executed command as list of strings,
    :return:        the named tuple above, or None when the program is not
                    a known compiler, when the command does not run the
                    compilation pass or when it has no source file. """

    # quit right now, if the program was not a C/C++ compiler
    compiler_and_arguments = split_compiler(command)
    if compiler_and_arguments is None:
        return None

    compiler, arguments = compiler_and_arguments
    flags = []
    files = []
    # options which tell that the compilation pass is not involved
    non_compiling = frozenset(['-E', '-S', '-cc1', '-M', '-MM', '-###'])
    # iterate on the compile options
    args = iter(arguments)
    for arg in args:
        # quit when compilation pass is not involved
        if arg in non_compiling:
            return None
        # ignore some flags
        elif arg in IGNORED_FLAGS:
            # The default value of `next` makes a truncated command line
            # (option without its value) harmless. A `StopIteration` from
            # here would escape into the generator which calls this method.
            for _ in range(IGNORED_FLAGS[arg]):
                next(args, None)
        elif re.match(r'^-(l|L|Wl,).+', arg):
            pass
        # some parameters could look like filename, take as compile option
        elif arg in {'-D', '-I'}:
            flags.append(arg)
            value = next(args, None)
            # the value is missing when the option is the last argument
            if value is not None:
                flags.append(value)
        # parameter which looks source file is taken...
        elif re.match(r'^[^-].+', arg) and is_source(arg):
            files.append(arg)
        # and consider everything else as compile option.
        else:
            flags.append(arg)
    # do extra check on number of source files
    if not files:
        return None
    return _Compilation(compiler=compiler, flags=flags, files=files)


def split_compiler(command):
    """ Decide whether the command is a compiler call, and split it up.

    The decision is based on the base name of the first element. A known
    wrapper can be followed by a compiler (or by another wrapper), or it
    can be followed directly by the parameters.

    :param command: the command to classify
    :return:        None if the command is not a compilation
                    (compiler_language, rest of the command) tuple if the
                    command is a compilation. """

    def matches_any(patterns, candidate):
        return any(pattern.match(candidate) for pattern in patterns)

    # an empty list has neither executable nor parameters
    if not command:
        return None

    executable = os.path.basename(command[0])
    parameters = command[1:]

    if COMPILER_WRAPPER_PATTERN.match(executable):
        # look for a compiler (or another wrapper) behind the wrapper
        wrapped = split_compiler(parameters)
        if wrapped is None:
            # the wrapper was called with the parameters only
            return ('c', parameters)
        return wrapped

    if matches_any(COMPILER_PATTERNS, executable):
        if matches_any(COMPILER_CPP_PATTERNS, executable):
            return 'c++', parameters
        return 'c', parameters

    return None


def is_source(filename):
    """ Decide from the file extension whether the file is a source file.

    :param filename: the name (or path) of the file,
    :return:         true if the extension, compared case insensitively,
                     is one of the accepted source file extensions. """

    source_extensions = {
        '.c', '.cc', '.cp', '.cpp', '.cxx', '.c++', '.m', '.mm', '.i', '.ii',
        '.mii'
    }
    extension = os.path.splitext(os.path.basename(filename))[1]
    return extension.lower() in source_extensions


if __name__ == "__main__":
    # the result of `main` becomes the exit status of the process
    exit_status = main()
    sys.exit(exit_status)
