import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import numpy as np
from parse_billing import parse_billing_data

# Result directories, resolved relative to the current working directory:
# one sub-directory per provider underneath the base path.
BASE_PATH = Path('.')
AWS_PATH, IBM_PATH = (BASE_PATH / provider for provider in ('aws', 'ibm'))

def load_json_file(path):
    """Load and parse a JSON file.

    Args:
        path: Filesystem path (``str`` or ``os.PathLike``) of the file to read.

    Returns:
        The decoded JSON value, as produced by ``json.load``.
    """
    # JSON text is UTF-8; decode it explicitly rather than relying on the
    # platform's locale encoding (e.g. cp1252 on Windows), which would
    # mis-decode or reject non-ASCII content.
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def parse_latency_data(path):
    """Read a node-latency measurement JSON file into a DataFrame.

    The decoded JSON is handed straight to ``pd.DataFrame``; no columns are
    renamed, filtered or converted here.
    """
    records = load_json_file(path)
    frame = pd.DataFrame(records)
    return frame

# --- AWS EKS, homogeneous node pool: CAS run ---
# Loaded at import time; main() reads these module-level names.
aws_homo_cas_path = AWS_PATH / 'homogenous/cas/7e2387d0-1300-43e1-a016-41d8bdba88ae-eks-cas'
aws_cas_latency, aws_cas_latency_quantiles = (
    parse_latency_data(aws_homo_cas_path / measurement_file)
    for measurement_file in (
        'nodeLatencyMeasurement-api-intensive.json',
        'nodeLatencyQuantilesMeasurement-api-intensive.json',
    )
)
aws_cas_billing = parse_billing_data(aws_homo_cas_path / 'aws_billing_metrics_cas.json')

# --- AWS EKS, homogeneous node pool: Karpenter run ---
aws_homo_karp_path = AWS_PATH / 'homogenous/karpenter/7e2387d0-1300-43e1-a016-41d8bdba88ae-eks-karpenter'
aws_karp_latency, aws_karp_latency_quantiles = (
    parse_latency_data(aws_homo_karp_path / measurement_file)
    for measurement_file in (
        'nodeLatencyMeasurement-api-intensive.json',
        'nodeLatencyQuantilesMeasurement-api-intensive.json',
    )
)
aws_karp_billing = parse_billing_data(aws_homo_karp_path / 'aws_billing_metrics_karpenter.json')

def _plot_latency_distribution(cas_latency, karp_latency, out_path):
    """Save a box plot comparing the 'latency' column of the two runs."""
    labelled = (('CAS', cas_latency), ('Karpenter', karp_latency))
    # Long-form frame (one row per sample, tagged with its autoscaler) so the
    # category names come from the data instead of an unsupported `labels=`.
    long_form = pd.concat(
        [
            pd.DataFrame({'Autoscaler': label, 'latency': frame['latency'].to_numpy()})
            for label, frame in labelled
        ],
        ignore_index=True,
    )
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Explicit order keeps CAS first and both categories on the axis.
        sns.boxplot(data=long_form, x='Autoscaler', y='latency',
                    order=['CAS', 'Karpenter'], ax=ax)
        ax.set_title('Node Provisioning Latency Distribution')
        ax.set_ylabel('Latency (seconds)')
        fig.savefig(out_path)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)


def _plot_cost_comparison(cas_billing, karp_billing, out_path):
    """Save a bar chart of the mean 'cost' value for each run."""
    means = [
        cas_billing['cost'].to_numpy().mean(),
        karp_billing['cost'].to_numpy().mean(),
    ]
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.bar(['CAS', 'Karpenter'], means)
        ax.set_title('Average Cost Comparison')
        ax.set_ylabel('Cost (USD)')
        fig.savefig(out_path)
    finally:
        plt.close(fig)


def _plot_latency_quantiles(cas_quantiles, karp_quantiles, out_path):
    """Save a grouped bar chart of per-quantile latency for each run."""
    quantiles = ['p50', 'p90', 'p95', 'p99']
    # NOTE(review): assumes the quantile files expose lowercase
    # p50/p90/p95/p99 columns; each bar is the mean of that column over all
    # rows of the file -- confirm this matches the measurement file layout.
    cas_values = [cas_quantiles[q].mean() for q in quantiles]
    karp_values = [karp_quantiles[q].mean() for q in quantiles]

    x = np.arange(len(quantiles))
    width = 0.35  # two side-by-side bars per group, centred on each tick

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.bar(x - width / 2, cas_values, width, label='CAS')
        ax.bar(x + width / 2, karp_values, width, label='Karpenter')
        ax.set_xlabel('Percentile')
        ax.set_ylabel('Latency (seconds)')
        ax.set_title('Latency Quantiles Comparison')
        ax.set_xticks(x)
        ax.set_xticklabels(quantiles)
        ax.legend()
        fig.savefig(out_path)
    finally:
        plt.close(fig)


def main(output_dir='plots'):
    """Render the CAS vs. Karpenter comparison plots for the AWS runs.

    Uses the module-level data loaded at import time and writes three files
    into *output_dir*: ``latency_distribution.png``,
    ``cost_comparison.png`` and ``latency_quantiles.png``.

    Args:
        output_dir: Directory the plots are written to. It is created
            (including parents) if missing. Defaults to ``'plots'``, relative
            to the current working directory.
    """
    out = Path(output_dir)
    # savefig does not create directories; without this the first plot
    # fails with FileNotFoundError when the output directory is absent.
    out.mkdir(parents=True, exist_ok=True)

    # 1. Node provisioning latency comparison
    _plot_latency_distribution(aws_cas_latency, aws_karp_latency,
                               out / 'latency_distribution.png')
    # 2. Cost analysis
    _plot_cost_comparison(aws_cas_billing, aws_karp_billing,
                          out / 'cost_comparison.png')
    # 3. Performance quantiles analysis
    _plot_latency_quantiles(aws_cas_latency_quantiles, aws_karp_latency_quantiles,
                            out / 'latency_quantiles.png')

if __name__ == "__main__":
    # Plots are generated only when run as a script. Importing the module
    # still executes the top-level data loading above, but does not call main().
    main()
