#!/usr/bin/env python3
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

'''Wrapper script to run the OTBN random instruction generator'''

import argparse
import json
import os
import random
import sys
from typing import Optional, cast

# Ensure that the OTBN utils directory is on sys.path. This means that RIG code
# can import modules like "shared.foo" and get the OTBN shared code.
#
# The paths are derived from the location of this script:
#   _RIG_DIR   the directory containing this file
#   _OTBN_DIR  two directory levels above _RIG_DIR (normalised)
#   _UTIL_DIR  the "util" directory directly under _OTBN_DIR
#
# This must run before the "shared.*" import below (hence the "noqa: E402"
# markers on those imports). The directory is appended, so entries already on
# sys.path take precedence over it.
_RIG_DIR = os.path.dirname(__file__)
_OTBN_DIR = os.path.normpath(os.path.join(_RIG_DIR, '../..'))
_UTIL_DIR = os.path.join(_OTBN_DIR, 'util')
sys.path.append(_UTIL_DIR)

from shared.insn_yaml import InsnsFile, load_insns_yaml  # noqa: E402

from rig.config import Config  # noqa: E402
from rig.init_data import InitData  # noqa: E402
from rig.rig import gen_program  # noqa: E402
from rig.snippet import Snippet  # noqa: E402


def get_insns_file() -> Optional[InsnsFile]:
    '''Load the instruction definitions with load_insns_yaml.

    Returns the loaded InsnsFile. If loading raises a RuntimeError, the error
    is printed to stderr and None is returned instead.

    '''
    try:
        loaded = load_insns_yaml()
    except RuntimeError as load_err:
        # Report the problem here so that callers only need to check for None.
        print(load_err, file=sys.stderr)
        return None

    return loaded


def get_named_config(name: str) -> Optional[Config]:
    '''Load the configuration called name from the rig/configs directory.

    The directory is resolved relative to _RIG_DIR. Returns the Config on
    success. If Config.load raises a RuntimeError, the error is printed to
    stderr and None is returned instead.

    '''
    configs_path = os.path.join(_RIG_DIR, 'rig/configs')
    try:
        loaded = Config.load(configs_path, name)
    except RuntimeError as load_err:
        # Report the problem here so that callers only need to check for None.
        print(load_err, file=sys.stderr)
        return None

    return loaded


def gen_main(args: argparse.Namespace) -> int:
    '''Entry point for the gen subcommand.

    Seeds the global random module from args.seed, loads the instruction
    definitions and the configuration named by args.config, runs gen_program
    and writes the result to stdout as a two-element JSON list (initial data
    followed by the snippet).

    Returns 0 on success, or 1 if loading or generation failed (the error will
    already have been printed to stderr).

    '''
    # Seed Python's global random generator before doing anything else.
    random.seed(args.seed)

    insns = get_insns_file()
    if insns is None:
        return 1

    cfg = get_named_config(args.config)
    if cfg is None:
        return 1

    # Run the generator
    try:
        init_data, snippet = gen_program(cfg, args.start_addr, args.size,
                                         insns)
    except RuntimeError as gen_err:
        print(gen_err, file=sys.stderr)
        return 1
    else:
        # Serialise as [init data, snippet] and write it to stdout.
        json.dump([init_data.as_json(), snippet.to_json()], sys.stdout)
        # json.dump doesn't emit a trailing newline; add one so that the
        # output doesn't make a mess of some consoles.
        sys.stdout.write('\n')

    return 0


def asm_main(args: argparse.Namespace) -> int:
    '''Entry point for the asm subcommand.

    Reads a JSON document from args.snippets (an open text file), which must
    be a two-element list of [init data, snippet], and converts it to an
    assembly listing.

    If args.output is None or '-', the assembly is written to stdout and no
    linker script is produced. Otherwise args.output is used as a base path:
    the assembly goes to <output>.s and a linker script to <output>.ld.

    Returns 0 on success, or 1 on any failure (after printing a message to
    stderr).

    '''

    insns_file = get_insns_file()
    if insns_file is None:
        return 1

    # Load JSON
    try:
        json_data = json.load(args.snippets)
    except json.JSONDecodeError as err:
        print('Snippets file at {!r} is not valid JSON: {}.'
              .format(args.snippets.name, err),
              file=sys.stderr)
        return 1

    # Parse these to proper init data and snippet objects.
    try:
        if not (isinstance(json_data, list) and len(json_data) == 2):
            raise ValueError('Top-level structure should be a length 2 list.')

        json_init_data, json_snippet = json_data
        init_data = InitData.read(json_init_data)
        snippet = Snippet.from_json(insns_file, [], json_snippet)
    except ValueError as err:
        print('Failed to parse snippets from {!r}: {}'
              .format(args.snippets.name, err),
              file=sys.stderr)
        return 1

    program = snippet.to_program()
    dsegs = init_data.as_segs()

    # Dump the assembly output, and the linker script too if we're writing to
    # something other than stdout.
    if args.output is None or args.output == '-':
        program.dump_asm(sys.stdout, dsegs)
        return 0

    # Compute the output paths outside the try blocks: string concatenation
    # cannot raise OSError, and the names are needed in the error messages.
    asm_path = args.output + '.s'
    ld_path = args.output + '.ld'

    try:
        with open(asm_path, 'w') as out_file:
            program.dump_asm(out_file, dsegs)
    except OSError as err:
        # Report the actual file name (with the '.s' suffix) rather than the
        # base path, matching what the linker script branch does.
        print('Failed to open asm output file {!r}: {}.'
              .format(asm_path, err),
              file=sys.stderr)
        return 1

    try:
        with open(ld_path, 'w') as out_file:
            program.dump_linker_script(out_file, dsegs)
    except OSError as err:
        print('Failed to open ld script output file {!r}: {}.'
              .format(ld_path, err),
              file=sys.stderr)
        return 1

    return 0


def main() -> int:
    '''Parse the command line and dispatch to the chosen subcommand.

    There are two subcommands, "gen" and "asm"; exactly one must be given.
    Each registers its handler as the "func" default, which is called with
    the parsed arguments. The handler's return value is returned as the exit
    status.

    '''
    top_parser = argparse.ArgumentParser()
    commands = top_parser.add_subparsers(dest='cmd')
    # Insist on a subcommand: there is nothing to do without one.
    commands.required = True

    # The gen subcommand
    gen_parser = commands.add_parser('gen', help='Generate a random program')
    gen_parser.add_argument('--seed', type=int, default=0,
                            help='Random seed. Defaults to 0.')
    gen_parser.add_argument('--size', type=int, default=100,
                            help=('Max number of instructions in stream. '
                                  'Defaults to 100.'))
    gen_parser.add_argument('--start-addr', type=int, default=0,
                            help='Reset address. Defaults to 0.')
    gen_parser.add_argument('--config', type=str, default='default',
                            help='Configuration to use')
    gen_parser.set_defaults(func=gen_main)

    # The asm subcommand
    asm_parser = commands.add_parser('asm',
                                     help='Convert snippets to assembly')
    asm_parser.add_argument('--output', '-o',
                            metavar='out',
                            help=('Base path for output filenames. Will '
                                  'generate out.s with an assembly listing '
                                  'and out.ld with a linker script. If this '
                                  'is not supplied, the assembly (but no '
                                  'linker script) will be dumped to stdout.'))
    # The snippets file is optional: read from stdin if no path is given.
    asm_parser.add_argument('snippets', metavar='path.json',
                            type=argparse.FileType('r'), nargs='?',
                            default=sys.stdin,
                            help=('A JSON file of snippets, as generated by '
                                  'otbn-rig gen.'))
    asm_parser.set_defaults(func=asm_main)

    parsed = top_parser.parse_args()
    status = parsed.func(parsed)
    return cast(int, status)


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
