#!/usr/bin/env python3

# The following script requires Python 3.9 or higher

import argparse
import json
import subprocess
from typing import Callable, TypedDict, Literal, Union


class SpreadLogDetail(TypedDict):
    """Detail payload carried by every parsed spread log entry."""
    # Raw text lines attached to the entry. For the result entries this file
    # inspects, each line holds one spread task prefixed with '-' and padded
    # with whitespace.
    lines: list[str]


class SpreadLog_TypePhase(TypedDict):
    """Log entry of type "phase": pairs a task with the phase verb logged for it."""
    # Discriminator identifying this entry shape.
    type: Literal["phase"]
    # Task the entry refers to; compared against the `spread -list` output.
    task: str
    # Phase verb; this file only looks for "Executing".
    # NOTE(review): presumably one of the values of log_helper.ExecutionPhase,
    # which is defined outside this file -- confirm.
    verb: str # log_helper.ExecutionPhase
    detail: SpreadLogDetail


class SpreadLog_TypeResult(TypedDict):
    """Log entry of type "result": a result summary whose detail lines list tasks."""
    # Discriminator identifying this entry shape.
    type: Literal["result"]
    # Outcome label; this file only checks for "Failed".
    # NOTE(review): presumably a log_helper.Result value -- confirm.
    result_type: str # log_helper.Result
    # Granularity label; this file compares it against "tasks" and "task".
    # NOTE(review): presumably a log_helper.ExecutionLevel value -- confirm.
    level: str # log_helper.ExecutionLevel
    # Stage label; this file compares it against "prepare" and "restore".
    stage: str
    detail: SpreadLogDetail


# A parsed log entry is either a phase entry or a result entry; the "type" key
# ("phase" / "result") tells the two shapes apart.
SpreadLog = Union[SpreadLog_TypePhase, SpreadLog_TypeResult]


def filter_with_spread(exec_param: list[str]) -> list[str]:
    """Return the lines printed by ``spread -list`` for the given expressions.

    Args:
        exec_param: spread expressions, passed to the command as separate
            arguments after ``-list``.

    Returns:
        The command's standard output split into lines.

    Raises:
        subprocess.CalledProcessError: if ``spread`` exits with a non-zero
            status.
    """
    listing = subprocess.check_output(
        ["spread", "-list", *exec_param], text=True)
    return listing.splitlines()


def list_executed_tasks(
    filtered_exec_param: set[str], spread_logs: list[SpreadLog]
) -> set[str]:
    """Return the selected tasks that have an "Executing" phase entry.

    Args:
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        The members of ``filtered_exec_param`` that appear as the ``task`` of
        a "phase" entry whose verb is "Executing".
    """
    executing: set[str] = set()
    for entry in spread_logs:
        # Only phase entries carry the "verb" and "task" keys.
        if entry["type"] != "phase":
            continue
        if entry["verb"] == "Executing":
            executing.add(entry["task"])
    return filtered_exec_param.intersection(executing)


def _get_detail_lines(
    spread_logs: list[SpreadLog], log_condition_func: Callable[[SpreadLog], bool]
) -> list[str]:
    """Collect the cleaned detail lines of every log entry accepted by a predicate.

    Args:
        spread_logs: parsed log entries.
        log_condition_func: predicate choosing which entries contribute.

    Returns:
        A flat list, in log order, of the detail lines of the accepted
        entries, each with surrounding whitespace and one leading '-' removed.
    """
    line_groups = [
        entry["detail"]["lines"]
        for entry in filter(log_condition_func, spread_logs)
    ]

    # Every detail line is a spread task prefixed with a '-' and surrounded
    # by whitespace, so trim, drop the dash, and trim again.
    cleaned: list[str] = []
    for group in line_groups:
        cleaned.extend(raw.strip().removeprefix("-").strip() for raw in group)
    return cleaned


def list_failed_tasks(
    filtered_exec_param: set[str], spread_logs: list[SpreadLog]
) -> set[str]:
    """Return the selected tasks listed in "Failed" results at the "tasks" level.

    Args:
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        The members of ``filtered_exec_param`` named in the detail lines of
        result entries with ``result_type == "Failed"`` and
        ``level == "tasks"``.
    """
    def is_failed_tasks_result(entry: SpreadLog) -> bool:
        # Check the discriminator first: only result entries carry the
        # "result_type" and "level" keys.
        if entry["type"] != "result":
            return False
        if entry["result_type"] != "Failed":
            return False
        return entry["level"] == "tasks"

    failed_names = _get_detail_lines(spread_logs, is_failed_tasks_result)
    return filtered_exec_param.intersection(failed_names)


def _list_failed(spread_logs: list[SpreadLog], level: str, stage: str) -> list[str]:
    """Return the detail lines of "Failed" results matching a level and stage.

    Args:
        spread_logs: parsed log entries.
        level: required value of the entry's ``level`` key.
        stage: required value of the entry's ``stage`` key.

    Returns:
        The cleaned detail lines of the matching result entries, in log
        order. The list is not restricted to any set of selected tasks.
    """
    def matches(entry: SpreadLog) -> bool:
        # The discriminator is tested first so phase entries are never asked
        # for result-only keys.
        is_failed_result = (
            entry["type"] == "result" and entry["result_type"] == "Failed"
        )
        return (
            is_failed_result
            and entry["level"] == level
            and entry["stage"] == stage
        )

    return _get_detail_lines(spread_logs, matches)


def list_executed_and_failed(
    filtered_exec_param: set[str], spread_logs: list[SpreadLog]
) -> set[str]:
    """Return the tasks that failed, excluding those that failed to prepare.

    The result is the failed selected tasks plus every task whose restore
    failed, minus every task whose prepare failed.

    Args:
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        The computed set. Restore failures are added without being checked
        against ``filtered_exec_param``, so the result may contain tasks
        outside the selection.
    """
    failed_tasks = list_failed_tasks(filtered_exec_param, spread_logs)
    prepare_failures = _list_failed(spread_logs, "task", "prepare")
    restore_failures = _list_failed(spread_logs, "task", "restore")
    return (failed_tasks | set(restore_failures)) - set(prepare_failures)


def list_aborted_tasks(
    filtered_exec_param: set[str], spread_logs: list[SpreadLog]
) -> set[str]:
    """Return the selected tasks that are considered aborted.

    When at least one selected task was executed, the aborted tasks are the
    selected tasks that were not executed. When none was executed, they are
    the selected tasks not reported by ``list_executed_and_failed``.

    Args:
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        A new set holding the aborted tasks.
    """
    executed = list_executed_tasks(filtered_exec_param, spread_logs)
    if executed:
        return filtered_exec_param.difference(executed)

    # Nothing was executed: fall back to the failure results.
    return filtered_exec_param.difference(
        list_executed_and_failed(filtered_exec_param, spread_logs)
    )


def list_successful_tasks(
    filtered_exec_param: set[str], spread_logs: list[SpreadLog]
) -> set[str]:
    """Return the selected tasks that were executed and did not fail.

    A task counts as failed when it is a failed selected task or when its
    restore failed.

    Args:
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        The executed selected tasks minus the failed ones.
    """
    executed = list_executed_tasks(filtered_exec_param, spread_logs)
    failures = list_failed_tasks(filtered_exec_param, spread_logs).union(
        _list_failed(spread_logs, "task", "restore")
    )
    return executed.difference(failures) if failures else executed


def list_rexecute_tasks(
    exec_param: Union[str, list[str]],
    filtered_exec_param: set[str],
    spread_logs: list[SpreadLog],
) -> set[str]:
    """Return what has to be re-executed to complete the run.

    The tasks to re-run are the aborted tasks plus the selected tasks reported
    by ``list_executed_and_failed``. When that covers every selected task, the
    original spread expressions are returned instead, so the whole selection
    is re-run as given.

    Args:
        exec_param: the spread expressions used for the run, either already
            split into a list or as one comma/whitespace separated string.
        filtered_exec_param: tasks selected by the spread expression.
        spread_logs: parsed log entries.

    Returns:
        Either the set of tasks to re-run or the set of original expressions.
    """
    # A bare string must be split into expressions first; calling set() on it
    # directly would produce a set of single characters.
    if isinstance(exec_param, str):
        exec_param = exec_param.replace(",", " ").split()

    aborted_tasks = list_aborted_tasks(filtered_exec_param, spread_logs)
    # list_executed_and_failed may report tasks outside the selection, so
    # restrict its result to the selected tasks.
    exec_and_failed = list_executed_and_failed(
        filtered_exec_param, spread_logs
    ).intersection(filtered_exec_param)
    to_rerun = aborted_tasks.union(exec_and_failed)

    # Every selected task needs a re-run: hand back the original expressions.
    if not filtered_exec_param.difference(to_rerun):
        return set(exec_param)
    return to_rerun


def add_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the two positional arguments shared by the log-based commands.

    Args:
        parser: the (sub)parser that receives ``exec_params`` and
            ``parsed_log``, in that order. ``parsed_log`` is opened for
            reading as UTF-8 text by argparse while parsing.
    """
    exec_params_help = (
        "This is the parameter used to run spread "
        "(something like this BACKEND:SYSTEM:SUITE)"
    )
    parsed_log_help = "This is the output generated by the log-parser tool"
    log_reader = argparse.FileType("r", encoding="utf-8")

    parser.add_argument("exec_params", help=exec_params_help)
    parser.add_argument("parsed_log", type=log_reader, help=parsed_log_help)


def main() -> None:
    """Command-line entry point of the log analyzer.

    Builds the sub-command parser, expands the spread expression(s) through
    ``spread -list``, loads the parsed log for the sub-commands that take one,
    and prints the selected tasks as a single space-separated line on stdout.

    Raises:
        RuntimeError: if the parsed log is empty or the command is unknown.
    """
    parser = argparse.ArgumentParser(
        description="""
        Usage: log-analyzer list-failed-tasks <EXEC-PARAM> <PARSED-LOG>
               log-analyzer list-executed-tasks <EXEC-PARAM> <PARSED-LOG>
               log-analyzer list-successful-tasks <EXEC-PARAM> <PARSED-LOG>
               log-analyzer list-aborted-tasks <EXEC-PARAM> <PARSED-LOG>
               log-analyzer list-all-tasks <EXEC-PARAM>
               log-analyzer list-reexecute-tasks <EXEC-PARAM> <PARSED-LOG>

        The log analyzer is an utility that provides useful information about a spread
        execution. The main functionality of the analyzer utility is to determine which tests
        have to be re-executed, including aborted tests that are not included in the test results.
        The log analyzer uses as input the spread expression that was used to run the tests.
        This expression determines which tests to considered. The second input is the output of
        the log-parser utility, which generates a json file including all the information
        extracted from the raw spread log.
                                     """,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True
    failed = subparsers.add_parser(
        "list-failed-tasks", help="list the tasks that failed during execute"
    )
    executed = subparsers.add_parser(
        "list-executed-tasks", help="list the tasks that were executed"
    )
    successful = subparsers.add_parser(
        "list-successful-tasks", help="list the successful tasks"
    )
    aborted = subparsers.add_parser(
        "list-aborted-tasks",
        help="list the aborted tasks (needs spread to be installed)",
    )
    reexecute = subparsers.add_parser(
        "list-reexecute-tasks",
        help="list the tasks to re-execute to complete (includes aborted and failed tasks)",
    )
    list_all = subparsers.add_parser(
        "list-all-tasks", help="list all the tasks")

    # Every log-based command takes <EXEC-PARAM> and <PARSED-LOG>.
    for log_command in (failed, executed, successful, aborted, reexecute):
        add_arguments(log_command)
    # list-all-tasks only needs the spread expression.
    list_all.add_argument(
        "exec_params",
        help="This is the parameter used to run spread (something like this BACKEND:SYSTEM:SUITE)",
    )

    args = parser.parse_args()

    # Expressions may be separated by commas and/or whitespace.
    exec_params = args.exec_params.replace(",", " ").split()
    filtered_exec_param = set(filter_with_spread(exec_params))

    # list-all-tasks has no parsed log, so it keeps this empty default.
    log: list[SpreadLog] = []
    if hasattr(args, "parsed_log"):
        # argparse.FileType opened the file while parsing; close it as soon as
        # the JSON has been read instead of leaking the handle.
        with args.parsed_log as parsed_log:
            log = json.load(parsed_log)
        if not log:
            raise RuntimeError("log.analyzer: the log file cannot be empty")

    if args.command == "list-failed-tasks":
        print(" ".join(list_failed_tasks(filtered_exec_param, log)))
    elif args.command == "list-executed-tasks":
        print(" ".join(list_executed_tasks(filtered_exec_param, log)))
    elif args.command == "list-successful-tasks":
        print(" ".join(list_successful_tasks(filtered_exec_param, log)))
    elif args.command == "list-aborted-tasks":
        print(" ".join(list_aborted_tasks(filtered_exec_param, log)))
    elif args.command == "list-all-tasks":
        print(" ".join(filtered_exec_param))
    elif args.command == "list-reexecute-tasks":
        print(" ".join(list_rexecute_tasks(
            exec_params, filtered_exec_param, log)))
    else:
        raise RuntimeError("log.analyzer: no such command: %s" % args.command)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
