#!/usr/bin/env python3
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

'''A wrapper around riscv32-unknown-elf-objdump for OTBN'''

import os
import re
import subprocess
import sys
from typing import Dict, List, Optional, Tuple

from shared.encoding import Encoding
from shared.insn_yaml import Insn, InsnsFile, load_insns_yaml
from shared.toolchain import find_tool


def snoop_disasm_flags(argv: List[str]) -> bool:
    '''Return True if argv contains a flag asking objdump to disassemble

    The recognised spellings are -d, -D, --disassemble, --disassemble-all and
    --disassemble=SYMBOL. Arguments are compared as whole strings, so short
    flags bundled with others (such as -dr) are not recognised.

    '''
    exact_flags = ('-d', '-D', '--disassemble', '--disassemble-all')
    return any(arg in exact_flags or arg.startswith('--disassemble=')
               for arg in argv)


def get_insn(raw: int, masks: List[Tuple[int, int, Insn]]) -> Optional[Insn]:
    '''Look up the instruction whose encoding matches a raw instruction word

    masks is a list of tuples (m0, m1, insn) as returned by get_insn_masks,
    where m0 has a bit set for every position that must be zero and m1 has a
    bit set for every position that must be one. Returns the matching insn,
    or None if no entry matches.

    '''
    match = None
    for must_be_zero, must_be_one, candidate in masks:
        # A candidate matches when no must-be-zero bit is set in raw and no
        # must-be-one bit is clear in raw.
        if not (raw & must_be_zero) and not (~raw & must_be_one):
            # The code in insn_yaml should already have checked that
            # encodings are unambiguous, but it can't hurt to be careful.
            assert match is None
            match = candidate

    return match


# Pattern for a line of objdump output that shows an instruction only as a raw
# word. OTBN instructions are 32 bits wide, so there's just one "word" in the
# second column. The lines that get passed through look like this:
#
#    84:   8006640b                0x8006640b
#
# We don't use a back-ref for the second copy of the data, because if the raw
# part has leading zeros, they don't appear there. For example:
#
#   6d0:   0000418b                0x418b
#
# Capture groups:
#   1: everything up to and including the whitespace after the raw word
#   2: the address, in hex
#   3: the raw instruction word (exactly 8 hex digits)
_RAW_INSN_RE = re.compile(r'([\s]*([0-9a-f]+):[\s]+([0-9a-f]{8})[\s]+)'
                          r'0x[0-9a-f]+\s*$')


def transform_disasm_line(line: str,
                          masks: List[Tuple[int, int, Insn]]) -> str:
    '''Replace a raw instruction word in an objdump line with OTBN disassembly

    line is one line of objdump output (without its newline) and masks is the
    list returned by get_insn_masks. If the line doesn't match _RAW_INSN_RE,
    or its raw word doesn't match any known instruction, it is returned
    unchanged. Otherwise the address / raw word prefix is kept and the trailing
    "0x..." is replaced by the instruction's disassembly.

    '''
    parsed = _RAW_INSN_RE.match(line)
    if parsed is None:
        return line

    # Groups are: the prefix to keep, the PC in hex, the raw word in hex.
    prefix, pc_hex, word_hex = parsed.group(1, 2, 3)

    # The raw word was exactly 8 hex characters, so it will fit in a u32.
    word = int(word_hex, 16)
    assert 0 <= word < (1 << 32)

    insn = get_insn(word, masks)
    if insn is None:
        # No known instruction has this bit pattern. Leave the line as-is.
        return line

    address = int(pc_hex, 16)

    # get_insn_masks only adds instructions that have an encoding, so this
    # should always hold.
    assert insn.encoding is not None

    # Pull the encoded field values out of the word, then let the instruction
    # turn them into operand values and render them.
    encoded_fields = insn.encoding.extract_operands(word)
    operands = insn.enc_vals_to_op_vals(address, encoded_fields)
    return prefix + insn.disassemble(address, operands)


def get_insn_masks(insns_file: InsnsFile) -> List[Tuple[int, int, Insn]]:
    '''Build the (zeros mask, ones mask, insn) table used to decode raw words

    Each element of the returned list is (m0, m1, insn) for one instruction in
    insns_file.insns that has an encoding; instructions without an encoding
    are skipped. No check is made here that the masks are unambiguous: that is
    supposed to happen in insn_yaml already, and get_insn does a
    belt-and-braces check on each lookup.

    '''
    table = []
    for insn in insns_file.insns:
        encoding = insn.encoding
        if encoding is not None:
            zeros, ones = encoding.get_masks()
            # Encoding.get_masks sets bits that are 'x' in both masks, so
            # remove from each mask any bit that also appears in the other.
            table.append((zeros & ~ones, ones & ~zeros, insn))
    return table


def main() -> int:
    '''Run objdump on the command line arguments, adding OTBN disassembly

    If the arguments don't ask for disassembly, objdump is run directly with
    its output streams untouched. Otherwise its output is captured and each
    line is passed through transform_disasm_line before being printed.

    Returns the exit status for the process: 127 if the objdump executable
    can't be found, objdump's own return code if that is nonzero, 1 if
    load_insns_yaml raises RuntimeError, and 0 otherwise.

    '''
    args = sys.argv[1:]
    has_disasm = snoop_disasm_flags(args)

    objdump = find_tool('objdump')
    try:
        if not has_disasm:
            # Nothing for us to rewrite: just behave like objdump.
            return subprocess.run([objdump] + args).returncode

        # Ask for numeric register names and no aliases, and capture the
        # output so that we can rewrite it below.
        proc = subprocess.run([objdump, '-M', 'numeric,no-aliases'] + args,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              universal_newlines=True)
    except FileNotFoundError:
        sys.stderr.write('Unknown command: {!r}.\n'
                         .format(objdump))
        return 127

    # We captured objdump's stderr, so pass it on whether or not objdump
    # succeeded. Otherwise any warnings from a successful run would be lost.
    sys.stderr.write(proc.stderr)

    if proc.returncode:
        # Dump any lines that objdump wrote before it died
        sys.stdout.write(proc.stdout)
        return proc.returncode

    try:
        insns_file = load_insns_yaml()
    except RuntimeError as err:
        sys.stderr.write('{}\n'.format(err))
        return 1

    insn_masks = get_insn_masks(insns_file)

    # If we get here, we think we're disassembling something, objdump ran
    # successfully and we have its results in proc.stdout
    lines = proc.stdout.split('\n')

    # Output that ends with a newline gives an empty final element. Drop it so
    # that we don't print a blank line that objdump never wrote.
    if lines[-1] == '':
        lines.pop()

    for line in lines:
        sys.stdout.write(transform_disasm_line(line, insn_masks) + '\n')

    return 0


# When run as a script, use the return value of main() as the exit status.
if __name__ == '__main__':
    sys.exit(main())
