from cython cimport floating, integral
from cython.parallel cimport parallel, prange
from libcpp.map cimport map as cpp_map, pair as cpp_pair
from libc.stdlib cimport free

from ...utils._typedefs cimport intp_t, float64_t
from ...utils.parallel import _get_threadpool_controller

import numpy as np
from scipy.sparse import issparse
from ._classmode cimport WeightingStrategy

{{for name_suffix in ["32", "64"]}}
from ._argkmin cimport ArgKmin{{name_suffix}}
from ._datasets_pair cimport DatasetsPair{{name_suffix}}

cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}):
    """
    {{name_suffix}}bit implementation of ArgKminClassMode.

    For every sample of X, the labels of its k nearest neighbors in Y are
    accumulated into a per-class score histogram (`class_scores`). The
    histogram rows are normalized to sum to one in `_finalize_results`.
    """
    cdef:
        # Label of each sample of Y. The values are used directly as column
        # indices into `class_scores`, so they are assumed to already be
        # encoded as integers in [0, unique_Y_labels.shape[0]).
        const intp_t[:] Y_labels
        # Only its length is used here, to size `class_scores`.
        const intp_t[:] unique_Y_labels
        # Buffer of shape (n_samples_X, unique_Y_labels.shape[0]).
        float64_t[:, :] class_scores
        # NOTE(review): not referenced anywhere in this class; kept so that
        # the object layout is unchanged.
        cpp_map[intp_t, intp_t] labels_to_index
        WeightingStrategy weight_type

    @classmethod
    def compute(
        cls,
        X,
        Y,
        intp_t k,
        weights,
        Y_labels,
        unique_Y_labels,
        str metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
    ):
        """Compute the argkmin reduction with Y_labels.

        This classmethod is responsible for introspecting the arguments
        values to dispatch to the most appropriate implementation of
        :class:`ArgKminClassMode{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated datastructures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        No instance _must_ directly be created outside of this class method.

        Parameters
        ----------
        X, Y
            The two datasets, forwarded with `metric` and `metric_kwargs` to
            `DatasetsPair{{name_suffix}}.get_for`.
        k : int
            Number of neighbors in Y considered for each sample of X.
        weights
            "uniform" or "distance". Any other value is stored as
            `WeightingStrategy.callable`, which the histogram currently
            treats like uniform weighting.
        Y_labels
            Label of each sample of Y, used as a column index into the
            score buffer.
        unique_Y_labels
            Unique labels; its length is the number of score columns.
        metric, chunk_size, metric_kwargs, strategy
            Forwarded unchanged to the underlying implementation.

        Returns
        -------
        probabilities : ndarray of shape (n_samples_X, len(unique_Y_labels))
            Row-normalized class scores.
        """
        # Use a generic implementation that handles most scipy
        # metrics by computing the distances between 2 vectors at a time.
        pda = ArgKminClassMode{{name_suffix}}(
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
            k=k,
            chunk_size=chunk_size,
            strategy=strategy,
            weights=weights,
            Y_labels=Y_labels,
            unique_Y_labels=unique_Y_labels,
        )

        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in GEMM for instance).
        with _get_threadpool_controller().limit(limits=1, user_api="blas"):
            # Both drivers are inherited from the parent class (not shown here).
            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results()

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        const intp_t[:] Y_labels,
        const intp_t[:] unique_Y_labels,
        chunk_size=None,
        strategy=None,
        intp_t k=1,
        weights=None,
    ):
        """Set up the parent reduction, the weighting mode and the score buffer."""
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )

        # Anything other than the two known strings (including None) falls
        # back to `WeightingStrategy.callable`.
        if weights == "uniform":
            self.weight_type = WeightingStrategy.uniform
        elif weights == "distance":
            self.weight_type = WeightingStrategy.distance
        else:
            self.weight_type = WeightingStrategy.callable

        self.Y_labels = Y_labels
        self.unique_Y_labels = unique_Y_labels

        # Buffer used in building a histogram for one-pass weighted mode:
        # one row per sample of X, one column per unique label.
        self.class_scores = np.zeros(
            (self.n_samples_X, unique_Y_labels.shape[0]), dtype=np.float64,
        )

    def _finalize_results(self):
        """Normalize each row of the score buffer so that it sums to one.

        The division is done in place on the array wrapping `class_scores`,
        which is then returned.
        """
        probabilities = np.asarray(self.class_scores)
        probabilities /= probabilities.sum(axis=1, keepdims=True)
        return probabilities

    cdef inline void weighted_histogram_mode(
        self,
        intp_t sample_index,
        intp_t* indices,
        float64_t* distances,
    ) noexcept nogil:
        """Add the votes of the k neighbors of one sample to `class_scores`.

        `indices` and `distances` each point to `self.k` contiguous entries
        describing the neighbors of the sample at row `sample_index`.

        With uniform weighting every neighbor adds 1 to the score of its
        class. With distance weighting every neighbor adds the inverse of
        its distance; if at least one neighbor is at distance zero, only the
        zero-distance neighbors vote (with weight 1). This avoids infinite
        scores, which would otherwise become NaN once rows are normalized.
        """
        cdef:
            intp_t neighbor_rank, neighbor_idx, neighbor_class_idx
            float64_t score_incr = 1
            # TODO: Implement other WeightingStrategy values
            bint use_distance_weighting = (
                self.weight_type == WeightingStrategy.distance
            )
            bint has_zero_distance = False

        # NOTE(review): the callers' buffer names (`heaps_r_distances_chunks`)
        # suggest these values may be rank-preserving surrogate distances
        # rather than true distances -- confirm before relying on the exact
        # inverse-distance weights.
        if use_distance_weighting:
            for neighbor_rank in range(self.k):
                if distances[neighbor_rank] == 0:
                    has_zero_distance = True
                    break

        # Iterate through the sample k-nearest neighbours
        for neighbor_rank in range(self.k):
            # TODO: inspect if it worth permuting this condition
            # and the for-loop above for improved branching.
            if use_distance_weighting:
                if has_zero_distance:
                    # Exact matches take all the weight.
                    if distances[neighbor_rank] == 0:
                        score_incr = 1
                    else:
                        score_incr = 0
                else:
                    score_incr = 1 / distances[neighbor_rank]
            # Absolute index of the neighbor_rank-th nearest neighbor
            # in range [0, n_samples_Y)
            neighbor_idx = indices[neighbor_rank]
            neighbor_class_idx = self.Y_labels[neighbor_idx]
            self.class_scores[sample_index, neighbor_class_idx] += score_incr

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Build the histogram rows for the samples X[X_start:X_end].

        The neighbors of the idx-th sample of the chunk are read from the
        buffers of thread `thread_num`, starting at offset `idx * self.k`.
        """
        cdef:
            intp_t idx, sample_index
        for idx in range(X_end - X_start):
            # One-pass top-one weighted mode
            # Compute the absolute index in [0, n_samples_X)
            sample_index = X_start + idx
            self.weighted_histogram_mode(
                sample_index,
                &self.heaps_indices_chunks[thread_num][idx * self.k],
                &self.heaps_r_distances_chunks[thread_num][idx * self.k],
            )

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        """Free the per-thread buffers and build every histogram row.

        Neighbors are read from the rows of `argkmin_indices` and
        `argkmin_distances` rather than from the per-thread buffers.
        """
        cdef:
            intp_t sample_index, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            for sample_index in prange(self.n_samples_X, schedule='static'):
                self.weighted_histogram_mode(
                    sample_index,
                    &self.argkmin_indices[sample_index][0],
                    &self.argkmin_distances[sample_index][0],
                )

{{endfor}}
