"""
Network analysis module for 20min.ch comments.

This module provides functions for analyzing the network of commenters
and their interactions.
"""

from typing import Dict, List, Tuple, Any, Optional
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter

class NetworkAnalyzer:
    """Analyze commenter networks from 20min.ch.

    Builds a directed reply graph in which nodes are authors and an edge
    A -> B (with a ``weight`` attribute) means A replied ``weight`` times to
    comments written by B. Provides graph metrics, centrality rankings,
    community detection, suspicious-activity heuristics and plotting.
    """

    def __init__(self, comments_df: Optional[pd.DataFrame] = None):
        """
        Initialize the network analyzer.

        Args:
            comments_df: Optional DataFrame with comments. When provided,
                the reply graph is built immediately.
        """
        self.comments_df = comments_df
        self.graph: Optional[nx.DiGraph] = None

        if comments_df is not None:
            self.build_graph_from_df(comments_df)

    def build_graph_from_df(self, comments_df: pd.DataFrame) -> nx.DiGraph:
        """
        Build a directed reply graph from a DataFrame of comments.

        Args:
            comments_df: DataFrame with at least the columns
                ``comment_id``, ``author`` and ``parent_id``.

        Returns:
            NetworkX DiGraph, also stored on ``self.graph``.

        Raises:
            ValueError: If a required column is missing.
        """
        if any(col not in comments_df.columns for col in ("comment_id", "author", "parent_id")):
            raise ValueError("DataFrame must contain comment_id, author, and parent_id columns")

        G = nx.DiGraph()

        # Every author becomes a node, seeded with a zero comment count.
        for author in comments_df["author"].unique():
            G.add_node(author, comment_count=0)

        for author, count in comments_df["author"].value_counts().items():
            G.nodes[author]["comment_count"] = count

        # Map comment_id -> author once so each reply resolves its parent in
        # O(1) instead of re-filtering the whole DataFrame per reply (O(n^2)).
        # drop_duplicates keeps the first occurrence, matching the previous
        # "first matching row" lookup semantics.
        unique_comments = comments_df.drop_duplicates(subset="comment_id")
        author_of_comment = dict(zip(unique_comments["comment_id"], unique_comments["author"]))

        # Add one weighted edge per (replier, parent author) pair.
        reply_rows = comments_df[comments_df["parent_id"].notna()]
        for _, row in reply_rows.iterrows():
            parent_author = author_of_comment.get(row["parent_id"])
            if parent_author is None:
                # Parent comment is not part of this DataFrame; skip the reply.
                continue

            replier = row["author"]
            if G.has_edge(replier, parent_author):
                G[replier][parent_author]["weight"] += 1
            else:
                G.add_edge(replier, parent_author, weight=1)

        self.graph = G
        return G

    def get_graph_metrics(self) -> Dict[str, Any]:
        """
        Calculate basic metrics for the commenter graph.

        Returns:
            Dictionary with node/edge counts, density, reciprocity, degree
            statistics and (when computable) component and clustering metrics.

        Raises:
            ValueError: If the graph has not been built yet.
        """
        if self.graph is None:
            raise ValueError("Graph has not been built. Call build_graph_from_df first.")

        G = self.graph

        metrics = {
            "num_nodes": G.number_of_nodes(),
            "num_edges": G.number_of_edges(),
            "density": nx.density(G),
            # nx.reciprocity raises NetworkXError on an edgeless graph;
            # report 0.0 instead of crashing.
            "reciprocity": nx.reciprocity(G) if G.number_of_edges() > 0 else 0.0,
        }

        in_degrees = [d for _, d in G.in_degree()]
        out_degrees = [d for _, d in G.out_degree()]

        metrics.update({
            "avg_in_degree": sum(in_degrees) / len(in_degrees) if in_degrees else 0,
            "avg_out_degree": sum(out_degrees) / len(out_degrees) if out_degrees else 0,
            "max_in_degree": max(in_degrees, default=0),
            "max_out_degree": max(out_degrees, default=0),
        })

        # Component/clustering metrics can fail on pathological graphs;
        # surface the error message instead of crashing.
        if G.number_of_nodes() > 0:
            try:
                wcc = list(nx.weakly_connected_components(G))
                metrics["num_connected_components"] = len(wcc)
                metrics["largest_component_size"] = len(max(wcc, key=len))

                # Undirected view gives a more interpretable clustering
                # coefficient than the directed graph.
                metrics["avg_clustering"] = nx.average_clustering(G.to_undirected())
            except Exception as e:
                metrics["error"] = str(e)

        return metrics

    def get_central_commenters(self, top_n: int = 10, metric: str = "degree") -> pd.DataFrame:
        """
        Get most central commenters based on network metrics.

        Args:
            top_n: Number of top commenters to return.
            metric: Centrality metric to use
                    ("degree", "in_degree", "out_degree", "betweenness", "eigenvector").

        Returns:
            DataFrame indexed by author with the chosen centrality score,
            the comment count (when available) and in/out degree.

        Raises:
            ValueError: If the graph is missing or the metric is unknown.
        """
        if self.graph is None:
            raise ValueError("Graph has not been built. Call build_graph_from_df first.")

        G = self.graph

        if metric == "degree":
            centrality = nx.degree_centrality(G)
        elif metric == "in_degree":
            centrality = nx.in_degree_centrality(G)
        elif metric == "out_degree":
            centrality = nx.out_degree_centrality(G)
        elif metric == "betweenness":
            centrality = nx.betweenness_centrality(G)
        elif metric == "eigenvector":
            try:
                centrality = nx.eigenvector_centrality(G, max_iter=1000)
            except nx.PowerIterationFailedConvergence:
                # The numpy solver converges where power iteration does not.
                centrality = nx.eigenvector_centrality_numpy(G)
        else:
            raise ValueError(f"Unknown centrality metric: {metric}")

        col = f"{metric}_centrality"
        centrality_df = pd.DataFrame.from_dict(centrality, orient="index", columns=[col])
        centrality_df = centrality_df.sort_values(col, ascending=False).head(top_n)

        # Attach comment counts when nodes carry them. get_node_attributes
        # is safe on an empty graph, unlike indexing the first node.
        comment_counts = nx.get_node_attributes(G, "comment_count")
        if comment_counts:
            centrality_df["comment_count"] = pd.Series(comment_counts)

        centrality_df["in_degree"] = pd.Series(dict(G.in_degree()))
        centrality_df["out_degree"] = pd.Series(dict(G.out_degree()))

        return centrality_df

    def find_communities(self, algorithm: str = "louvain") -> Dict[str, Any]:
        """
        Detect communities in the commenter network.

        Args:
            algorithm: Community detection algorithm to use
                       ("louvain", "greedy_modularity", "label_propagation").
                       "louvain" silently falls back to "greedy_modularity"
                       when the python-louvain package is not installed.

        Returns:
            Dictionary with the number of communities, their sizes, the
            top-5 most central members per community, and the node ->
            community-id mapping.

        Raises:
            ValueError: If the graph is missing or the algorithm is unknown.
        """
        if self.graph is None:
            raise ValueError("Graph has not been built. Call build_graph_from_df first.")

        # Validate up front: previously an unknown algorithm fell through
        # every branch and crashed with NameError on `partition`.
        if algorithm not in ("louvain", "greedy_modularity", "label_propagation"):
            raise ValueError(f"Unknown community detection algorithm: {algorithm}")

        G = self.graph
        # These algorithms operate on undirected graphs.
        undir_G = G.to_undirected()

        partition: Dict[Any, int] = {}

        if algorithm == "louvain":
            try:
                from community import best_partition
                partition = best_partition(undir_G)
            except ImportError:
                # Fallback if python-louvain is not installed.
                algorithm = "greedy_modularity"

        if algorithm == "greedy_modularity":
            communities = nx.community.greedy_modularity_communities(undir_G)
            partition = {node: i for i, comm in enumerate(communities) for node in comm}
        elif algorithm == "label_propagation":
            communities = nx.community.label_propagation_communities(undir_G)
            partition = {node: i for i, comm in enumerate(communities) for node in comm}

        community_sizes = Counter(partition.values())

        # Group members per community in one pass, and compute centrality
        # once (it was previously recomputed inside the loop, and works on
        # the actual community ids rather than assuming 0..k-1).
        members_by_comm: Dict[int, List[Any]] = {}
        for node, comm_id in partition.items():
            members_by_comm.setdefault(comm_id, []).append(node)

        centrality = nx.degree_centrality(G)

        community_leaders = {}
        for comm_id, members in members_by_comm.items():
            comm_centrality = {node: centrality[node] for node in members}
            top_commenters = sorted(comm_centrality.items(), key=lambda x: x[1], reverse=True)[:5]
            community_leaders[comm_id] = top_commenters

        return {
            "num_communities": len(community_sizes),
            "community_sizes": dict(community_sizes),
            "community_leaders": community_leaders,
            "node_community": partition,
        }

    def detect_potential_suspicious_activity(self) -> pd.DataFrame:
        """
        Detect potentially suspicious commenting patterns.

        Two heuristics are applied:
        - "targeted_replies": a user with at least 3 replies who directs
          more than 50% of them at a single other user.
        - "isolated_activity": a user with 10+ comments whose interactions
          (in-degree + out-degree) cover less than 10% of their comments.

        Returns:
            DataFrame with one row per flagged (user, pattern) pair; empty
            when nothing is flagged.

        Raises:
            ValueError: If the graph or the comments DataFrame is missing.
        """
        if self.graph is None or self.comments_df is None:
            raise ValueError("Both graph and comments DataFrame are needed for this analysis")

        G = self.graph
        suspicious_users = []

        # 1. Users whose replies concentrate on a single target.
        for node in G.nodes():
            out_edges = list(G.out_edges(node, data=True))
            if len(out_edges) < 3:
                continue  # too little activity to judge

            target_counts = Counter(target for _, target, _ in out_edges)
            most_common = target_counts.most_common(1)
            if most_common and (most_common[0][1] / len(out_edges)) > 0.5:
                suspicious_users.append({
                    "author": node,
                    "suspicious_pattern": "targeted_replies",
                    "details": f"{most_common[0][1]}/{len(out_edges)} replies to {most_common[0][0]}",
                    "target_user": most_common[0][0],
                    "concentration_ratio": most_common[0][1] / len(out_edges),
                })

        # 2. Prolific users with almost no interactions (possible sock puppets).
        in_degrees = dict(G.in_degree())
        out_degrees = dict(G.out_degree())

        for node in G.nodes():
            comment_count = G.nodes[node].get("comment_count", 0)
            if comment_count < 10:
                continue  # skip low-volume users

            in_degree = in_degrees.get(node, 0)
            out_degree = out_degrees.get(node, 0)
            interaction_ratio = (in_degree + out_degree) / comment_count

            if interaction_ratio < 0.1:  # fewer than 10% of comments involve interactions
                suspicious_users.append({
                    "author": node,
                    "suspicious_pattern": "isolated_activity",
                    "details": f"{comment_count} comments but only {in_degree+out_degree} interactions",
                    "comment_count": comment_count,
                    "interaction_count": in_degree + out_degree,
                    "interaction_ratio": interaction_ratio,
                })

        # 3. TODO: flag groups of users who primarily interact with each
        #    other (needs community detection plus analysis of inter- vs
        #    intra-community links).

        return pd.DataFrame(suspicious_users)

    def plot_network(self, figsize=(12, 10), node_size_factor=50, min_edge_weight=1, title=None):
        """
        Plot the commenter network graph.

        Args:
            figsize: Figure size as a (width, height) tuple.
            node_size_factor: Factor to scale node sizes (based on comment count).
            min_edge_weight: Minimum edge weight to include.
            title: Optional title for the plot.

        Returns:
            (figure, axis) tuple, or (None, None) if filtering removed
            every node.

        Raises:
            ValueError: If the graph has not been built yet.
        """
        if self.graph is None:
            raise ValueError("Graph has not been built. Call build_graph_from_df first.")

        # Work on a copy so filtering does not mutate the stored graph.
        G = self.graph.copy()

        if min_edge_weight > 1:
            weak_edges = [(u, v) for u, v, d in G.edges(data=True)
                          if d.get('weight', 1) < min_edge_weight]
            G.remove_edges_from(weak_edges)

        # Drop nodes left without any edge after filtering.
        G.remove_nodes_from(list(nx.isolates(G)))

        if G.number_of_nodes() == 0:
            print("No nodes to display after filtering.")
            return None, None

        fig, ax = plt.subplots(figsize=figsize)

        # Fruchterman-Reingold layout; the fixed seed keeps plots reproducible.
        pos = nx.spring_layout(G, k=0.3, seed=42)

        node_sizes = [G.nodes[node].get('comment_count', 1) * node_size_factor for node in G.nodes()]
        edge_weights = [G[u][v].get('weight', 1) for u, v in G.edges()]

        nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color='skyblue', alpha=0.8, ax=ax)
        nx.draw_networkx_edges(G, pos, width=edge_weights, alpha=0.5, edge_color='gray', ax=ax)

        # Only label the most active commenters to keep the plot readable.
        large_commenters = [n for n in G.nodes()
                            if G.nodes[n].get('comment_count', 0) > G.number_of_nodes() // 10]
        nx.draw_networkx_labels(G, pos, {n: n for n in large_commenters}, font_size=10, ax=ax)

        # Target this figure's axis explicitly instead of pyplot's
        # implicit "current axes".
        if title:
            ax.set_title(title)
        ax.axis('off')

        return fig, ax