import pandas as pd
import os
from pathlib import Path
from typing import Union
import seaborn as sns
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import filedialog, simpledialog, messagebox, ttk


# Gates the debug-only checks under `if DEBUG:` in this module.
DEBUG = False

# Both CSV files are read at import time; the paths are relative to the
# current working directory.
pd_ant = pd.read_csv("data_sets/TxAntennaDAB.csv", engine='python')
pd_param = pd.read_csv("data_sets/TxParamsDAB.csv", encoding="iso-8859-1")

# Join the two tables on their shared 'id' column (pd.merge defaults to an
# inner join, so only ids present in both files are kept).
dataframe = pd.merge(pd_ant, pd_param, on='id')

def save_as_json(df: pd.DataFrame, filepath: str) -> Union[str, None]:
    """Serialise ``df`` to JSON through ``DataFrame.to_json``.

    Args:
        df: Frame to serialise.
        filepath: Destination passed straight through to ``DataFrame.to_json``.

    Returns:
        Whatever ``DataFrame.to_json`` returns for the given ``filepath``.
    """
    json_result = df.to_json(filepath)
    return json_result

def clean_data(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` without excluded NGRs and without undated rows.

    A row is removed when its "NGR" value is one of the excluded grid
    references below, or when its "Date" value is null. Rows are selected with
    a boolean mask rather than dropped by index label, so a frame whose index
    contains duplicate labels does not lose unrelated rows.

    Args:
        df: Frame with "NGR" and "Date" columns. It is not modified.

    Returns:
        A new DataFrame holding only the rows that were kept.
    """
    excluded_NGRs = [
        "NZ02553847",
        "SE213515",
        "NT05399374",
        "NT252675908",
    ]
    keep = ~df["NGR"].isin(excluded_NGRs) & df["Date"].notna()
    # Copy so that later column assignments on the result never write into
    # (or warn about) a slice of the caller's frame.
    return df.loc[keep].copy()

def load_data():
    """Ask for a CSV file and a display name, then register the dataset.

    The chosen name is appended to the ``data_list`` listbox and the path to
    ``data_paths``; afterwards ``check_dataset_selected`` is called to refresh
    the UI. Nothing is registered when the user cancels either dialog. An
    empty name falls back to the file's base name.
    """
    file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    if not file_path:
        # The file dialog was cancelled: there is no path to register.
        return

    default_name = os.path.basename(file_path)
    dataset_name = simpledialog.askstring("Input", "Enter the name for the dataset", initialvalue=default_name)
    if dataset_name is None:
        # The name dialog was cancelled.
        return

    data_list.insert(tk.END, dataset_name or default_name)
    data_paths.append(file_path)
    messagebox.showinfo("Success", "Data loaded successfully")
    check_dataset_selected()

def load_and_save_prepared_data():
    """Ask for a CSV file and a display name, then register the dataset.

    The chosen name is appended to the ``data_list`` listbox and the path to
    ``data_paths``. Nothing is registered when the user cancels either dialog.
    An empty name falls back to the file's base name.

    NOTE(review): despite the function name and the success message, nothing
    is written to disk here; only the selected path is recorded.
    """
    file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    if not file_path:
        # The file dialog was cancelled: there is no path to register.
        return

    default_name = os.path.basename(file_path)
    dataset_name = simpledialog.askstring("Input", "Enter the name for the dataset", initialvalue=default_name)
    if dataset_name is None:
        # The name dialog was cancelled.
        return

    data_list.insert(tk.END, dataset_name or default_name)
    data_paths.append(file_path)
    messagebox.showinfo("Success", "Prepared data loaded and saved successfully")


def generate_stats(df: pd.DataFrame) -> pd.DataFrame:
    """Normalise the date and ERP columns, then filter the rows of interest.

    "Date" is reduced to the integer after its last "/" (the year for
    day/month/year strings). "In-Use ERP Total" is parsed from text that uses
    "." as thousands separator and "," as decimal mark. The result keeps rows
    with "Site Height" above 75, a year of 2001 or later, and an "EID" of
    C18A, C18F or C188.

    Despite its name, the function only prepares and filters rows; summary
    statistics are available from ``get_mean_mode_median``.

    Args:
        df: Frame with string "Date" and "In-Use ERP Total" columns plus
            "Site Height" and "EID". It is not modified.

    Returns:
        The filtered rows, with "Date" as int and "In-Use ERP Total" as float.
        May be empty.
    """
    # Work on a copy so the caller's frame keeps its original string columns
    # and the function can be called again on the same input.
    df = df.copy()
    df['Date'] = df['Date'].str.split('/').str[-1].astype(int)
    # regex=False: "." must be treated literally. As a regular expression it
    # would match every character and wipe the whole value.
    df['In-Use ERP Total'] = (
        df['In-Use ERP Total']
        .str.replace('.', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    selected = (
        (df['Site Height'] > 75)
        & (df['Date'] >= 2001)
        & (df['EID'].isin(['C18A', 'C18F', 'C188']))
    )
    return df[selected]

def get_mean_mode_median(df: pd.DataFrame, key="In-Use ERP Total"):
    """Return ``[mean, mode, median]`` of column ``key``.

    Args:
        df: Frame containing the column.
        key: Name of the column to summarise.

    Returns:
        A three-element list: the mean, the first entry of the column's
        ``mode()`` result, and the median.
    """
    column = df[key]
    average = column.mean()
    most_common = column.mode()[0]
    middle = column.median()
    return [average, most_common, middle]

def melt_df(df: pd.DataFrame):
    """Reshape the service-label columns of ``df`` into long format.

    'Site', 'Freq.' and 'Block' are kept as identifier columns. Each listed
    service-label column becomes rows whose column name is stored in
    'Service Label Type' and whose cell value is stored in
    'Service Label Value'.
    """
    identifier_columns = ['Site', 'Freq.', 'Block']
    # The trailing spaces are part of the column names being selected.
    label_columns = [
        'Serv Label1 ',
        'Serv Label2 ',
        'Serv Label3 ',
        'Serv Label4 ',
        'Serv Label10 ',
    ]
    return df.melt(
        id_vars=identifier_columns,
        value_vars=label_columns,
        var_name='Service Label Type',
        value_name='Service Label Value',
    )

def catplot_df(df: pd.DataFrame):
    """Show a strip-kind categorical plot of label values per block.

    'Block' is on the x axis and 'Service Label Value' on the y axis, coloured
    by 'Service Label Type', with one column of facets per 'Site'. The x tick
    labels are rotated 45 degrees before the figure is shown.
    """
    plot_options = {
        "x": "Block",
        "y": "Service Label Value",
        "hue": "Service Label Type",
        "col": "Site",
        "data": df,
        "kind": "strip",
        "height": 4,
        "aspect": 1,
    }
    grid = sns.catplot(**plot_options)

    grid.set_xticklabels(rotation=45)
    plt.show()

def facetplot_df(df: pd.DataFrame):
    """Show one strip plot per 'Service Label Type', wrapped three per row.

    Each facet gets its own y axis. The columns 'Block',
    'Service Label Value' and 'Site' are handed positionally to
    ``sns.stripplot`` (presumably as x, y and hue; confirm against the
    installed seaborn version).
    """
    grid = sns.FacetGrid(
        df,
        col="Service Label Type",
        col_wrap=3,
        height=4,
        sharey=False,
    )
    strip_options = {"jitter": True, "palette": "Set2", "order": None}
    grid.map(sns.stripplot, "Block", "Service Label Value", "Site", **strip_options)
    grid.set_axis_labels("Block", "Service Label Value")
    plt.show()

def boxplot_df(df: pd.DataFrame):
    """Show a 15x8 box plot of 'Freq.' per 'Block', split by label type."""
    figure_size = (15, 8)
    plt.figure(figsize=figure_size)

    axis_columns = {"x": "Block", "y": "Freq.", "hue": "Service Label Type"}
    sns.boxplot(data=df, **axis_columns)

    plt.title('Box Plot of Service Labels across Blocks')
    plt.show()

def scatterplot_df(df: pd.DataFrame):
    """Show a scatter plot of label values per block, sized by frequency.

    Points are coloured by 'Service Label Type' and scaled by 'Freq.' between
    marker sizes 10 and 200. The legend is anchored outside the axes at the
    upper left of the point (1.05, 1).
    """
    plt.figure(figsize=(6, 2))

    scatter_options = {
        "x": "Block",
        "y": "Service Label Value",
        "hue": "Service Label Type",
        "size": "Freq.",
        "sizes": (10, 200),
        "alpha": 0.7,
    }
    sns.scatterplot(data=df, **scatter_options)

    plt.title('Scatter Plot of Service Labels across Blocks by Frequency')
    plt.xticks(rotation=45)
    legend_anchor = (1.05, 1)
    plt.legend(bbox_to_anchor=legend_anchor, loc='upper left')
    plt.tight_layout()
    plt.show()

def check_dataset_selected(*args):
    """Refresh the preview table and the button states for the selection.

    With a dataset selected in ``data_list``, its CSV (looked up in
    ``data_paths``) is loaded into the ``data_preview`` tree and the
    dataset-dependent buttons are enabled. Without a selection, or when the
    file cannot be read, the preview is cleared and the buttons are disabled.

    Args:
        *args: Ignored; accepted so the function can be bound as a Tk event
            handler as well as called directly.
    """
    state = 'disabled'
    preview_df = None

    selected = data_list.curselection()
    if selected:
        file_path = data_paths[selected[0]]
        try:
            preview_df = pd.read_csv(file_path)
        except (OSError, ValueError) as exc:
            # pandas' parser and empty-data errors and decoding errors are
            # ValueError subclasses; missing or unreadable files are OSError.
            messagebox.showerror("Error", f"Could not read {file_path}: {exc}")
        else:
            state = 'normal'

    for row in data_preview.get_children():
        data_preview.delete(row)

    if preview_df is not None:
        data_preview["column"] = list(preview_df.columns)
        data_preview["show"] = "headings"
        for column in data_preview["columns"]:
            data_preview.heading(column, text=column)

        for row in preview_df.to_numpy().tolist():
            data_preview.insert("", "end", values=row)

    # Only touch buttons that actually exist: the startup code in this module
    # creates load_and_save_button but not the other two, and assigning to a
    # missing global would raise NameError on every call.
    for button_name in ("load_and_save_button",
                        "generate_output_button",
                        "manipulate_values_button"):
        button = globals().get(button_name)
        if button is not None:
            button['state'] = state

def _correlation_p_value(r, dof):
    """Return the two-sided t-test p-value for correlation ``r`` with ``dof`` degrees of freedom."""
    if dof <= 0 or np.isnan(r):
        # No valid test can be made; NaN compares False against any threshold.
        return float("nan")
    if abs(r) >= 1:
        # Perfect correlation: the t statistic is infinite, so p is zero.
        return 0.0
    t_statistic = r * np.sqrt(dof / (1 - r**2))
    # sf (= 1 - cdf) keeps precision for very small tail probabilities.
    return 2 * scipy.stats.t.sf(np.abs(t_statistic), dof)


def check_correlation(df: pd.DataFrame):
    """Print, for every ordered pair of columns, whether they correlate significantly.

    Text columns are replaced by their category codes, the frame is pivoted
    to one row per ('Site', 'Freq.', 'Block') with one column per
    'Service Label Type', and the correlation matrix of that pivot is tested
    pair by pair with a two-sided t-test at the 0.05 level. Each pair is
    reported in both orders.

    Args:
        df: Long-format frame with 'Site', 'Freq.', 'Block',
            'Service Label Type' and 'Service Label Value' columns. It is
            not modified.
    """
    # Encode on a copy so the caller's text columns are left intact.
    encoded = df.copy()
    for col in encoded.columns:
        column = encoded[col]
        if pd.api.types.is_object_dtype(column) or pd.api.types.is_string_dtype(column):
            encoded[col] = column.astype('category').cat.codes

    df_pivot = encoded.pivot_table(index=['Site', 'Freq.', 'Block'],
                                   columns='Service Label Type',
                                   values='Service Label Value',
                                   aggfunc='first').reset_index()

    correlation_matrix = df_pivot.corr()
    correlations = correlation_matrix.values
    labels = correlation_matrix.columns

    degrees_of_freedom = len(df_pivot) - 2
    significance_level = 0.05

    for i, first in enumerate(labels):
        for j, second in enumerate(labels):
            if i == j:
                continue
            p_value = _correlation_p_value(correlations[i, j], degrees_of_freedom)
            if p_value < significance_level:
                print(f"The correlation between {first} and {second} is statistically significant.")
            else:
                print(f"The correlation between {first} and {second} is not statistically significant.")

def graph(df: pd.DataFrame):
    """Clean, filter and melt ``df``, then show the service-label scatter plot.

    Args:
        df: Raw merged frame. It is passed through ``clean_data``,
            ``generate_stats`` and ``melt_df`` before plotting.
    """
    # Use the frame that was passed in, not the module-level ``dataframe``.
    cleaned_data = clean_data(df)
    stats = generate_stats(cleaned_data)
    melted = melt_df(stats)
    # catplot_df, facetplot_df and boxplot_df take the same melted frame and
    # can be swapped in here for a different view.
    scatterplot_df(melted)
# NOTE: this call runs at import time, before the `__main__` guard below is
# evaluated, so the plot is produced before any GUI is built.
graph(dataframe)

# Build and run the Tk user interface when the module is executed as a script.
if __name__ == "__main__":
    root = tk.Tk()
    root.geometry("1000x600")
    root.title("Data Processing Application")
    root.config(bg="white")

    style = ttk.Style(root)
    style.theme_use('clam')

    # '.' is the root ttk style, so this font applies to every ttk widget.
    style.configure('.', font=('Helvetica', 12))

    # Column 0: frame holding the action buttons.
    frame = ttk.Frame(root, padding="10 10 10 10")
    frame.grid(row=0, column=0, sticky='nsew')

    load_data_button = ttk.Button(frame, text="Load dataset", command=load_data)
    load_data_button.pack(fill='x', pady=5)

    # Starts disabled; check_dataset_selected sets its state from the current
    # list selection.
    # NOTE(review): check_dataset_selected also assigns to
    # generate_output_button and manipulate_values_button, neither of which is
    # created in this block.
    load_and_save_button = ttk.Button(frame, text="Load and save prepared data", command=load_and_save_prepared_data, state='disabled')
    load_and_save_button.pack(fill='x', pady=5)

    # Column 1: names of the registered datasets.
    data_list = tk.Listbox(root, font=('Helvetica', 12), bg='white')
    data_list.grid(row=0, column=1, sticky='nsew', padx=5, pady=5)
    data_list.bind('<<ListboxSelect>>', check_dataset_selected)

    # Column 2: tabular preview of the selected dataset.
    data_preview = ttk.Treeview(root)
    data_preview.grid(row=0, column=2, sticky='nsew', padx=5, pady=5)

    quit_button = ttk.Button(root, text="Quit", command=root.destroy)
    quit_button.grid(row=1, column=0, sticky='s', padx=5, pady=5)

    # Relative widths of the three columns when the window is resized.
    root.grid_columnconfigure(0, weight=1)
    root.grid_columnconfigure(1, weight=3)
    root.grid_columnconfigure(2, weight=5)

    # Rows 0 and 1 share any extra vertical space equally.
    for i in range(2):
        root.grid_rowconfigure(i, weight=1)

    # File paths of the registered datasets; load_data and
    # load_and_save_prepared_data append here and to data_list together, so
    # positions in the two line up.
    data_paths = []

    root.mainloop()

if DEBUG:
    pre_func_call_shape = dataframe.shape

    dataframe = clean_data(dataframe)
    assert dataframe.shape != pre_func_call_shape, "the shape is the same :("

    breakpo