# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import ast

import numpy as np
import pandas as pd

from nsys_recipe.lib import arm_metrics as am
from nsys_recipe.lib import exceptions, overlap


class Architecture:
    """String identifiers for the CPU architectures this module distinguishes."""

    # x86 family.
    X86_64 = "x86_64"

    # 64-bit Arm variants.
    AARCH64_SBSA = "aarch64_sbsa"
    AARCH64_TEGRA = "aarch64_tegra"

    # Fallback identifier.
    UNKNOWN = "unknown"


# Thread state codes compared against the `threadState` field of
# thread scheduling events.
THREAD_STATE_UNKNOWN = 0
THREAD_STATE_TERMINATED = 5


class SchedEventError(exceptions.ValueError):
    """Raised when thread scheduling events are missing or inconsistent."""


def _get_cpu_activity_ranges(thread_sched_df):
    """
    Get CPU activity ranges.

    This function returns an array of CPU time ranges, where such a range means that
    in a particular CPU, a particular thread from a particular process was scheduled
    for that time.

    The input DataFrame `thread_sched_df` represents the SQL
    thread scheduling events table `SCHED_EVENTS`. The columns read here are
    `globalTid`, `start`, `isSchedIn`, `cpu`, `threadState`, `tid` and `pid`.
    The output DataFrame will contain CPU activity ranges
    (cpu, start, end, tid, pid), sorted by `cpu` and `start`.

    If no complete scheduling-in/scheduling-out pair is found (including the
    case of an empty input), an empty DataFrame with the same five columns
    is returned.

    Raises `SchedEventError` when, within one thread, a scheduling-in or
    scheduling-out event is missing, or when a thread is scheduled out on
    a different CPU than the one it was scheduled in on.
    """
    output_columns = ["cpu", "start", "end", "tid", "pid"]

    def get_missing_event_msg(row, sched_in_missing):
        return (
            f"The scheduling {'in' if sched_in_missing else 'out'} "
            + f"event is missing for the thread "
            + f"with TID {row.tid} and PID {row.pid} on CPU {row.cpu}, "
            + f"the timestamp of the scheduling "
            + f"{'out' if sched_in_missing else 'in'} event: {row.start}."
        )

    def get_missing_sched_in_event_msg(row):
        return get_missing_event_msg(row, True)

    def get_missing_sched_out_event_msg(row):
        return get_missing_event_msg(row, False)

    def get_cpu_mismatch_msg(prev_row, row):
        return (
            f"The scheduling out event occured on a different CPU "
            + f"for the thread with TID {prev_row.tid} and PID {prev_row.pid}. "
            + f"The thread was scheduled in on CPU {prev_row.cpu} "
            + f"scheduled out on CPU {row.cpu}, "
            + f"scheduling in time: {prev_row.start}, scheduling out time: {row.start}."
        )

    thread_sched_df = thread_sched_df.sort_values(by=["globalTid", "start"])
    thread_sched_rows = thread_sched_df.itertuples()

    # `next()` is given a default so that an empty table does not leak
    # a StopIteration to the caller.
    prev_row = next(thread_sched_rows, None)

    # We are skipping the situation where the first event for the thread is
    # a scheduling-out event, as it is possible that the thread was scheduled
    # in before the profiling was started.
    if prev_row is not None and not prev_row.isSchedIn:
        prev_row = next(thread_sched_rows, None)

    cpu_ranges = []
    # Whenever `prev_row` is None the iterator is already exhausted,
    # so the loop body never sees a None `prev_row`.
    for row in thread_sched_rows:
        if prev_row.globalTid == row.globalTid:
            if not prev_row.isSchedIn:
                if not row.isSchedIn:
                    # NSys can generate two scheduling-out events in a row
                    # for the same cpu when the target thread terminates.
                    # These will be two scheduling-out events with different thread
                    # states: the first with the TERMINATED thread state
                    # and the second with the UNKNOWN thread state.
                    # In such a case we ignore the second event and continue
                    # processing the next one.
                    if not (
                        prev_row.cpu == row.cpu
                        and prev_row.threadState == THREAD_STATE_TERMINATED
                        and row.threadState == THREAD_STATE_UNKNOWN
                    ):
                        raise SchedEventError(get_missing_sched_in_event_msg(row))
            else:
                if row.isSchedIn:
                    raise SchedEventError(get_missing_sched_out_event_msg(prev_row))

                if prev_row.cpu != row.cpu:
                    raise SchedEventError(get_cpu_mismatch_msg(prev_row, row))

                cpu_ranges.append(
                    {
                        "cpu": prev_row.cpu,
                        "start": prev_row.start,
                        "end": row.start,
                        "tid": prev_row.tid,
                        "pid": prev_row.pid,
                    }
                )
        # else:
        #     The case where `prev_row` and `row` are related to different threads.
        #     This means that `prev_row` is the last scheduling event for its thread,
        #     and `row` is the first scheduling event for its thread.
        #     Note, that we don't throw errors for the following situations,
        #     we simply skip these unexpected thread scheduling events:
        #     1. The last event for the thread is a scheduling-in event, as it is
        #        possible that the thread was scheduled out after the profiling was
        #        stopped.
        #     2. The first event for the thread is a scheduling-out event, as it is
        #        possible that the thread was scheduled in before the profiling was
        #        started.

        prev_row = row

    # Passing the column list explicitly keeps the sort below valid
    # even when no ranges were collected.
    cpu_df = pd.DataFrame(cpu_ranges, columns=output_columns)
    return cpu_df.sort_values(by=["cpu", "start"]).reset_index(drop=True)


def _get_perf_event_scaler_for_range(
    perf_event_start, perf_event_end, range_start, range_end
):
    """
    Return the fraction of a perf event that falls inside a range.

    The result is 1 when the perf event lies completely inside the range,
    0 when the two do not overlap, and otherwise the overlapping time
    divided by the duration of the perf event.
    """
    starts_inside = perf_event_start >= range_start
    ends_inside = perf_event_end <= range_end

    # |--------------------------------------------|
    #       |-----------Range-----------|
    #               |--Perf Event--|
    if starts_inside and ends_inside:
        return 1

    if starts_inside:
        # The perf event begins at or after the end of the range: no overlap.
        if perf_event_start >= range_end:
            return 0
        # |--------------------------------------------|
        #       |-----------Range-----------|
        #                              |--Perf Event--|
        overlap_time = range_end - perf_event_start
    else:
        # The perf event finishes at or before the start of the range: no overlap.
        if perf_event_end <= range_start:
            return 0
        if ends_inside:
            # |--------------------------------------------|
            #       |-----------Range-----------|
            # |--Perf Event--|
            overlap_time = perf_event_end - range_start
        else:
            # |--------------------------------------------|
            #                  |-Range-|
            #               |--Perf Event--|
            overlap_time = range_end - range_start

    return overlap_time / (perf_event_end - perf_event_start)


def _compute_core_perf_events(ranges_df, core_perf_df, cpu_df, rely_on_tid):
    """
    Compute core perf events and their number of samples for the given ranges.

    If `rely_on_tid` is set to True, the function will only consider events
    that occurred on CPUs on which the thread of a particular range was running.

    If `rely_on_tid` is set to False, the function will only consider events
    that occurred on CPUs on which the process of a particular range was running
    (data from all threads active while the range was running will be included).

    The output tuple will contain two DataFrames:
    (core perf events, their number of samples)
    collected by nsys with a row corresponding to the range.

    The perf events DataFrame has one column per perf event name, and the
    samples DataFrame has one column per "<perf event name>_<cpu>" pair.
    When `ranges_df` has no rows, two empty DataFrames indexed like
    `ranges_df` are returned.
    """
    perf_events = []
    perf_samples = []
    perf_event_names = core_perf_df["name"].unique()

    # The loop variable is deliberately not called `range`
    # so that the builtin is not shadowed.
    for rng in ranges_df.itertuples():
        # 1. Filter perf events for the range of interest by time overlap.
        range_perf_df = core_perf_df[
            (core_perf_df["start"] <= rng.end) & (core_perf_df["end"] >= rng.start)
        ]

        # 2. Attribute CPU activity ranges to the range of interest by time overlap,
        #    PID and TID (if rely_on_tid=True) match.
        cpu_mask = (
            (cpu_df["pid"] == rng.pid)
            & (cpu_df["start"] <= rng.end)
            & (cpu_df["end"] >= rng.start)
        )
        if rely_on_tid:
            cpu_mask = cpu_mask & (cpu_df["tid"] == rng.tid)
        range_cpu_df = cpu_df[cpu_mask]

        # 3. Filter perf events by CPU activity ranges using map_overlapping_ranges().
        range_perf_gdf = range_perf_df.groupby("cpu")
        range_cpu_gdf = range_cpu_df.groupby("cpu")

        # The two lists are kept in lockstep: entry i of each list forms
        # one (perf event, CPU activity range) pair.
        range_perf_indices = []
        range_cpu_indices = []
        for cpu, curr_cpu_df in range_cpu_gdf:
            if cpu not in range_perf_gdf.groups:
                continue

            curr_perf_df = range_perf_gdf.get_group(cpu)

            perf_to_cpu_idx_map = overlap.map_overlapping_ranges(
                curr_perf_df, curr_cpu_df, key_df="df1"
            )

            for perf_idx, cpu_indices in perf_to_cpu_idx_map.items():
                range_perf_indices.extend([perf_idx] * len(cpu_indices))
                range_cpu_indices.extend(cpu_indices)

        range_perf_df = range_perf_df.loc[range_perf_indices]
        range_cpu_df = range_cpu_df.loc[range_cpu_indices]

        range_perf_events = {x: 0 for x in perf_event_names}
        range_perf_samples = {
            f"{x}_{y}": 0 for x in perf_event_names for y in range_cpu_df["cpu"].values
        }

        # Here the `_get_perf_event_scaler_for_range()` is called for each
        # portion of the range that happens on a particular CPU. E.g.,
        # for each case from a to f in the example below:
        #        |----------------------Timeline------------------------|
        #          |------------------Range (TID 1)------------------|
        #
        #             a             b                     c
        # CPU 1:   |-TID1-|      |-TID1-|              |-TID1-|
        # Perf:  |----PE----|----PE----|----PE----|----PE----|----PE----|
        #              1          2          3          4          5
        #
        #                    d                e                  f
        # CPU 2:          |-TID1-|      |----TID1------|      |-TID1-|
        # Perf:  |----PE----|----PE----|----PE----|----PE----|----PE----|
        #              1          2          3          4          5
        # Case a:
        # - For perf event 1 on CPU 1:
        #    count_scaler_a_1 = time(a) / time(PE 1).
        # Case b:
        # - For perf event 2 on CPU 1:
        #   count_scaler_b_2 = (end(PE 2) - start(b)) / time(PE 2)
        # - For perf event 3 on CPU 1:
        #   count_scaler_b_3 = (end(b) - start(PE 3)) / time(PE 3)
        # Case c:
        # - For perf event 4 on CPU 1:
        #   count_scaler_c_4 = (end(PE 4) - start(c)) / time(PE 4)
        # - For perf event 5 on CPU 1:
        #   count_scaler_c_5 = (end(c) - start(PE 5)) / time(PE 5)
        # Case d:
        # - For perf event 1 on CPU 2:
        #   count_scaler_d_1 = (end(PE 1) - start(d)) / time(PE 1)
        # - For perf event 2 on CPU 2:
        #   count_scaler_d_2 = (end(d) - start(PE 2)) / time(PE 2)
        # Case e:
        # - For perf event 3 on CPU 2:
        #   count_scaler_e_3 = (end(PE 3) - start(e)) / time(PE 3)
        # - For perf event 4 on CPU 2:
        #   count_scaler_e_4 = (end(e) - start(PE 4)) / time(PE 4)
        # Case f:
        # - For perf event 5 on CPU 2:
        #   count_scaler_f_5 = time(f) / time(PE 5)
        #
        # The total perf event value for the whole range
        # will be calculated as follows:
        # total value of perf event =
        #     count_scaler_a_1 * value(PE 1 on CPU 1) +
        #     count_scaler_b_2 * value(PE 2 on CPU 1) +
        #     count_scaler_b_3 * value(PE 3 on CPU 1) +
        #     count_scaler_c_4 * value(PE 4 on CPU 1) +
        #     count_scaler_c_5 * value(PE 5 on CPU 1) +
        #     count_scaler_d_1 * value(PE 1 on CPU 2) +
        #     count_scaler_d_2 * value(PE 2 on CPU 2) +
        #     count_scaler_e_3 * value(PE 3 on CPU 2) +
        #     count_scaler_e_4 * value(PE 4 on CPU 2) +
        #     count_scaler_f_5 * value(PE 5 on CPU 2)
        #
        # The total number of samples for the whole range
        # will be calculated per each CPU as follows:
        # total number of samples of perf event for CPU 1 =
        #     count_scaler_a_1 + count_scaler_b_2 + count_scaler_b_3 +
        #     count_scaler_c_4 + count_scaler_c_5
        # total number of samples of perf event for CPU 2 =
        #     count_scaler_d_1 + count_scaler_d_2 + count_scaler_e_3 +
        #     count_scaler_e_4 + count_scaler_f_5

        for perf_event, cpu_activity_range in zip(
            range_perf_df.itertuples(), range_cpu_df.itertuples()
        ):

            # Calculate the start and end times of the range portion
            # that does the actual work on the CPU.
            range_start = max(rng.start, cpu_activity_range.start)
            range_end = min(rng.end, cpu_activity_range.end)

            count_scaler = _get_perf_event_scaler_for_range(
                perf_event.start, perf_event.end, range_start, range_end
            )

            range_perf_events[perf_event.name] += round(count_scaler * perf_event.count)
            key = f"{perf_event.name}_{cpu_activity_range.cpu}"
            range_perf_samples[key] += count_scaler

        range_perf_events_df = pd.DataFrame(range_perf_events, index=[rng.Index])
        range_perf_samples_df = pd.DataFrame(range_perf_samples, index=[rng.Index])

        perf_events.append(range_perf_events_df)
        perf_samples.append(range_perf_samples_df)

    # `pd.concat()` raises a ValueError on an empty list, so the case of
    # no input ranges is handled explicitly.
    if not perf_events:
        return (
            pd.DataFrame(columns=perf_event_names, index=ranges_df.index),
            pd.DataFrame(index=ranges_df.index),
        )

    return pd.concat(perf_events), pd.concat(perf_samples)


def compute_core_perf_events(
    ranges_df, core_perf_df, thread_sched_df, rely_on_tid=True
):
    """
    Compute core perf events for the provided ranges.

    The thread scheduling events are first narrowed down to the threads
    (`rely_on_tid=True`) or processes (`rely_on_tid=False`) that appear
    in `ranges_df`, then converted to CPU activity ranges.

    The output tuple will contain two DataFrames:
    (core perf events, their number of samples)
    collected by nsys with a row corresponding to the range.
    """
    # Column used to match scheduling events against the ranges of interest.
    match_column = "globalTid" if rely_on_tid else "pid"

    ids_of_interest = ranges_df[match_column].unique()
    is_of_interest = thread_sched_df[match_column].isin(ids_of_interest)

    cpu_df = _get_cpu_activity_ranges(thread_sched_df[is_of_interest])
    return _compute_core_perf_events(ranges_df, core_perf_df, cpu_df, rely_on_tid)


class Equation:
    class DataExtractor(ast.NodeVisitor):
        def __init__(self):
            self._operands = []

        def visit_Name(self, node):
            self._operands.append(node.id)
            return node

        @property
        def operands(self):
            return self._operands

    def __init__(self, equation_str):
        self._ast = ast.parse(equation_str, mode="eval")
        data_extractor = Equation.DataExtractor()
        data_extractor.visit(self._ast)
        self._operands = data_extractor.operands

    def run(self, df):
        context = {}
        for operand in self._operands:
            if operand not in df.columns:
                return None
            context[operand] = df[operand].values

        try:
            res = eval(
                compile(self._ast, filename="", mode="eval"),
                {"__builtins__": None},
                context,
            )
        except ZeroDivisionError as e:
            res = np.inf
        return res


def _parse_perf_metric_equations(metric_infos):
    """
    Build an `Equation` for every metric info.

    The returned list has the same length and order as `metric_infos`.
    Positions holding None in the input stay None in the output.
    """
    return [
        None if info is None else Equation(info.equation) for info in metric_infos
    ]


# Cache of the parsed Arm metric equations, filled on first use.
_arm_metric_equations = None


def _get_metric_equations(cpu_arch):
    """
    Return the metric equations available for `cpu_arch`.

    Equations are only provided for `Architecture.AARCH64_SBSA`. They are
    parsed on the first request and reused afterwards. None is returned
    for any other architecture.
    """
    global _arm_metric_equations

    if cpu_arch != Architecture.AARCH64_SBSA:
        return None

    if _arm_metric_equations is None:
        _arm_metric_equations = _parse_perf_metric_equations(am.get_arm_metrics())
    return _arm_metric_equations


def compute_perf_metrics(ranges_df, time_column, cpu_arch):
    """
    Compute performance metrics for the provided ranges.

    The values of `time_column` are exposed to the metric equations under
    the name `TIME`. The input DataFrame itself is not modified.

    The output DataFrame will contain performance metrics (as columns)
    with a row corresponding to the range. It has no columns when no
    equations are available for `cpu_arch`.
    """
    # Work on a copy so that the extra `TIME` column never leaks to the caller.
    operands_df = ranges_df.copy()
    operands_df["TIME"] = operands_df[time_column]

    metrics_df = pd.DataFrame(index=operands_df.index)

    equations = _get_metric_equations(cpu_arch)
    if equations is None:
        return metrics_df

    for metric_idx, equation in enumerate(equations):
        if equation is None:
            continue

        metric_name = am.PerfMetricType(metric_idx).name
        metric_values = equation.run(operands_df)
        # `run()` yields None when one of the operands is not available.
        if metric_values is not None:
            metrics_df[metric_name] = metric_values

    return metrics_df
