import pandas as pd
import numpy as np
import re
import matplotlib.pyplot as plt
from constants import HES


def categorize_title(title: object, patterns: dict) -> str:
    """
    Categorize a title based on a dictionary of patterns.

    Patterns are tried in the dictionary's iteration order and the first one
    that ``re.search`` finds in the title wins.

    Parameters:
    title (object): The title to categorize. Non-string values (e.g. the
        float NaN pandas uses for blank CSV cells) never match any pattern.
    patterns (dict): A dictionary where the keys are the categories and the
        values are the regex patterns (str or compiled) to search for.

    Returns:
    str: The category of the first matching pattern, or "Uncategorized" when
    nothing matches."""
    # re.search raises TypeError for non-str input; a single blank cell would
    # otherwise abort the whole Series.apply in the caller.
    if not isinstance(title, str):
        return "Uncategorized"
    for category, pattern in patterns.items():
        if re.search(pattern, title):
            return category
    return "Uncategorized"  # For rows that don't fit any of the patterns


def get_category(
    df: pd.DataFrame,
    column: str,
    categories: list,
    cat: str,
    date_format: str = "%Y %b",
) -> pd.DataFrame:
    """
    Get a subset of a DataFrame based on the category of the titles in a column.

    Every value of ``column`` is labelled with ``categorize_title`` using one
    regex per entry of ``categories`` (tried in this order, first match wins):

    - ``categories[0]``: no 4-digit year anywhere in the title;
    - ``categories[1]``: the title is exactly a 4-digit year;
    - ``categories[2]``: starts with a year and contains Q1-Q4;
    - ``categories[3]``: starts with a year and contains an upper-case
      three-letter month abbreviation (JAN ... DEC).

    Rows labelled ``cat`` are kept, rows with any missing value are dropped,
    and the two remaining columns are renamed to ``date`` and ``value``.
    The input DataFrame is not modified.

    Parameters:
    df (pd.DataFrame): The DataFrame to filter. Must have exactly two columns
        (the title column and one value column) for the renaming to succeed.
    column (str): The column containing the titles.
    categories (list): Four category labels, one per pattern above. If a label
        repeats, only the pattern of its last occurrence is kept.
    cat (str): The label of the rows to return.
    date_format (str): ``strftime`` format used to parse the titles into
        dates. Defaults to "%Y %b" (e.g. "2020 JAN"), which fits monthly titles.

    Returns:
    pd.DataFrame: Columns ``date`` (datetime) and ``value`` (float) with a
    fresh 0..n-1 index.

    Raises:
    IndexError: If ``categories`` has fewer than four entries.
    ValueError: Raised by pandas if the selection does not have exactly two
        columns, a title does not match ``date_format`` or a value is not
        numeric."""
    patterns = {
        categories[0]: r"^(?!.*\b\d{4}\b).*$",  # No 4-digit year anywhere in the title
        categories[1]: r"^\b\d{4}\b$",  # Starts with a 4-digit year and nothing else
        categories[2]: r"^\b\d{4}\b.*\bQ[1-4]\b",  # Year, then "Q1".."Q4"
        categories[
            3
        ]: r"^\b\d{4}\b.*\b(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)\b",  # Year, then a month name
    }

    # Keep the labels in a separate Series rather than writing a "category"
    # column into ``df``, so the caller's DataFrame is left untouched.
    labels = df[column].apply(categorize_title, patterns=patterns)
    # Explicit copy: the columns are reassigned below.
    result = df.loc[labels == cat].dropna().copy()
    result.columns = ["date", "value"]
    result["date"] = pd.to_datetime(result["date"], format=date_format)
    result["value"] = result["value"].astype(float)
    return result.reset_index(drop=True)


def read_cpih(
    file_path: str, medical: bool = True, category: str = "Month"
) -> pd.DataFrame:
    """
    Read the CPIH data from a CSV file and return a DataFrame.

    The CSV must have a "Title" column. Its rows are classified by
    ``get_category`` under the labels "Metadata" (titles without a 4-digit
    year), "Year", "Quarter" and "Month", and only rows labelled ``category``
    are returned.

    Parameters:
    file_path (str): The path to the CSV file.
    medical (bool): Not used by this function; accepted so existing callers
        keep working.
    category (str): The category of the data to extract. Titles are parsed
        with the "%Y %b" format (e.g. "2020 JAN"), which only "Month" titles
        match.

    Returns:
    pd.DataFrame: The CPIH data with columns ``date`` and ``value``."""
    # Give every pattern its own label. Repeating "Month" at positions 0 and 3
    # made the dict keys collide and silently dropped the "no-year" pattern.
    return get_category(
        pd.read_csv(file_path),
        "Title",
        ["Metadata", "Year", "Quarter", "Month"],
        category,
    )


def read_hes(
    file_path: str,
):
    """
    Read the HES data from a CSV file and return a DataFrame.

    Steps:
    1. Convert CALENDAR_MONTH_END_DATE strings to datetimes. Every "-" is
       replaced with " 20" and the text is upper-cased before parsing.
       NOTE(review): this assumes values like "Apr-19" (2-digit year after a
       hyphen); confirm against the source file.
    2. Drop columns that are entirely empty.
    3. Sort by that date and drop rows with any missing value.
    4. Rename the date column to ``date`` and move each date back with
       ``pd.offsets.MonthBegin(-1)``.
       NOTE(review): the parsed dates presumably already fall on the 1st. On a
       month-start date, ``MonthBegin(-1)`` moves to the start of the
       *previous* month, so every row ends up labelled one month earlier.
       Confirm that shift is intended.

    Parameters:
    file_path (str): The path to the CSV file.

    Returns:
    pd.DataFrame: The HES data with a ``date`` column and a fresh
    0..n-1 index."""
    frame = pd.read_csv(file_path)
    text_dates = (
        frame["CALENDAR_MONTH_END_DATE"].str.replace("-", " 20").str.upper()
    )
    frame["CALENDAR_MONTH_END_DATE"] = pd.to_datetime(text_dates, format="mixed")

    tidy = (
        frame.dropna(axis=1, how="all")
        .sort_values(by="CALENDAR_MONTH_END_DATE")
        .dropna()
        .reset_index(drop=True)
        .rename(columns={"CALENDAR_MONTH_END_DATE": "date"})
    )
    # The column already holds datetimes here, so this to_datetime call
    # leaves the values unchanged; the offset does the real work.
    tidy["date"] = pd.to_datetime(tidy["date"], format="%Y %b") + pd.offsets.MonthBegin(-1)
    return tidy


def get_global_df(
    cpih: pd.DataFrame, cpim: pd.DataFrame, hes: pd.DataFrame
) -> pd.DataFrame:
    """
    Merge the CPIH, CPIM and HES data into a single DataFrame.

    The three frames are inner-joined on ``date``, so only dates present in
    all of them are kept. The ``value`` columns of ``cpih`` and ``cpim`` come
    out of the first merge as ``value_x``/``value_y`` and are renamed to
    ``cpih`` and ``cpih_medical``. ``date`` is then replaced by ``year`` and
    ``month`` columns appended at the end.

    Parameters:
    cpih (pd.DataFrame): The CPIH data.
    cpim (pd.DataFrame): The CPIM data.
    hes (pd.DataFrame): The HES data.

    Returns:
    pd.DataFrame: The merged DataFrame."""
    merged = cpih.merge(cpim, on="date", how="inner").merge(
        hes, on="date", how="inner"
    )
    merged = merged.rename(columns={"value_x": "cpih", "value_y": "cpih_medical"})

    timestamps = merged.pop("date")
    merged["year"] = timestamps.dt.year
    merged["month"] = timestamps.dt.month
    return merged


def get_final_df(joined_data: pd.DataFrame) -> pd.DataFrame:
    """
    Create the final DataFrame for training and testing.

    Each row holds:
    - ``target``: that month's medical CPIH;
    - ``cpim_lag1..3``: the medical CPIH of the previous three rows;
    - ``cpih_lag1..3``: the headline CPIH of the previous three rows;
    - the previous row's values of the ``HES`` columns (names taken from
      ``constants.HES``).

    Rows with any missing value, including the first three rows that lack a
    full set of lags, are dropped.

    Rows are ordered by date before lagging because ``shift`` is positional.
    NOTE(review): lags count rows, not calendar months. A month missing from
    ``joined_data`` (e.g. dropped by the inner merges) makes a lag span more
    than one month; confirm the input has no gaps.

    Parameters:
    joined_data (pd.DataFrame): The merged DataFrame with ``year``, ``month``,
        ``cpih``, ``cpih_medical`` and the ``HES`` columns. It is not modified.

    Returns:
    pd.DataFrame: Columns ``date``, ``target``, the six CPI lag columns and
    the ``HES`` columns, with a fresh 0..n-1 index."""
    # Work on a date-sorted copy: the caller's frame stays untouched, and the
    # positional shifts below line up with calendar order.
    data = (
        joined_data.assign(
            date=pd.to_datetime(joined_data[["year", "month"]].assign(day=1))
        )
        .sort_values("date", kind="stable")
        .reset_index(drop=True)
    )

    final_data = pd.DataFrame(
        {"date": data["date"], "target": data["cpih_medical"]}
    )
    for lag in (1, 2, 3):
        final_data[f"cpim_lag{lag}"] = data["cpih_medical"].shift(lag)
    for lag in (1, 2, 3):
        final_data[f"cpih_lag{lag}"] = data["cpih"].shift(lag)
    # Only the previous row's HES values are used as features.
    final_data[HES] = data[HES].shift(1)
    return final_data.dropna().reset_index(drop=True)