from libc.stdlib cimport free, malloc
from libc.float cimport DBL_MAX
from cython cimport final
from cython.parallel cimport parallel, prange

from ...utils._heap cimport heap_push
from ...utils._sorting cimport simultaneous_sort
from ...utils._typedefs cimport intp_t, float64_t

import numpy as np
import warnings

from numbers import Integral
from scipy.sparse import issparse
from ...utils import check_array, check_scalar
from ...utils.fixes import _in_unstable_openblas_configuration
from ...utils.parallel import _get_threadpool_controller

{{for name_suffix in ['64', '32']}}

from ._base cimport (
    BaseDistancesReduction{{name_suffix}},
    _sqeuclidean_row_norms{{name_suffix}},
)

from ._datasets_pair cimport DatasetsPair{{name_suffix}}

from ._middle_term_computer cimport MiddleTermComputer{{name_suffix}}


cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
    """float{{name_suffix}} implementation of the ArgKmin.

    For each row of X, heaps of size `k` are filled with the (surrogate)
    distances to the rows of Y and the matching Y indices; the heaps are then
    sorted in ascending order of distance before being returned.

    Two sets of buffers are involved:
      - `argkmin_indices` / `argkmin_distances`: the main heaps, of shape
        (n_samples_X, k), which back the returned results;
      - `heaps_indices_chunks` / `heaps_r_distances_chunks`: one pointer per
        effective thread, pointing either inside the main heaps (when
        parallelizing on X) or to thread-local heaps (when parallelizing on Y).
    """

    @classmethod
    def compute(
        cls,
        X,
        Y,
        intp_t k,
        metric="euclidean",
        chunk_size=None,
        dict metric_kwargs=None,
        str strategy=None,
        bint return_distance=False,
    ):
        """Compute the argkmin reduction.

        This classmethod is responsible for introspecting the arguments
        values to dispatch to the most appropriate implementation of
        :class:`ArgKmin{{name_suffix}}`.

        This allows decoupling the API entirely from the implementation details
        whilst maintaining RAII: all temporarily allocated datastructures necessary
        for the concrete implementation are therefore freed when this classmethod
        returns.

        No instance should directly be created outside of this class method.

        Parameters
        ----------
        X, Y : datasets
            The two datasets; the reduction is computed for each row of X
            over the rows of Y.
        k : int
            Size of the heap kept for each row of X.
        metric : default="euclidean"
            "euclidean" and "sqeuclidean" are dispatched to the specialised
            Euclidean implementation; any other value is forwarded to
            `DatasetsPair.get_for`.
        chunk_size, strategy : default=None
            Forwarded unchanged to the constructor of the chosen implementation.
        metric_kwargs : dict, default=None
            Extra arguments forwarded to the chosen implementation.
        return_distance : bool, default=False
            Whether the distances are returned alongside the indices.

        Returns
        -------
        The array of indices of shape (n_samples_X, k) or, if `return_distance`
        is set, the tuple `(distances, indices)` of two such arrays.
        """
        # Limit the number of threads in second level of nested parallelism for BLAS
        # to avoid threads over-subscription (in DOT or GEMM for instance).
        with _get_threadpool_controller().limit(limits=1, user_api='blas'):
            if metric in ("euclidean", "sqeuclidean"):
                # Specialized implementation of ArgKmin for the Euclidean distance
                # for the dense-dense and sparse-sparse cases.
                # This implementation computes the distances by chunk using
                # a decomposition of the Squared Euclidean distance.
                # This specialisation has an improved arithmetic intensity for both
                # the dense and sparse settings, allowing in most cases speed-ups of
                # several orders of magnitude compared to the generic ArgKmin
                # implementation.
                # Note that squared norms of X and Y are precomputed in the
                # constructor of this class by issuing BLAS calls that may use
                # multithreading (depending on the BLAS implementation), hence calling
                # the constructor needs to be protected under the threadpool_limits
                # context, along with the main calls to _parallel_on_Y and
                # _parallel_on_X.
                # For more information see MiddleTermComputer.
                use_squared_distances = metric == "sqeuclidean"
                pda = EuclideanArgKmin{{name_suffix}}(
                    X=X, Y=Y, k=k,
                    use_squared_distances=use_squared_distances,
                    chunk_size=chunk_size,
                    strategy=strategy,
                    metric_kwargs=metric_kwargs,
                )
            else:
                # Fall back on a generic implementation that handles most scipy
                # metrics by computing the distances between 2 vectors at a time.
                pda = ArgKmin{{name_suffix}}(
                    datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs),
                    k=k,
                    chunk_size=chunk_size,
                    strategy=strategy,
                )

            if pda.execute_in_parallel_on_Y:
                pda._parallel_on_Y()
            else:
                pda._parallel_on_X()

        return pda._finalize_results(return_distance)

    def __init__(
        self,
        DatasetsPair{{name_suffix}} datasets_pair,
        chunk_size=None,
        strategy=None,
        intp_t k=1,
    ):
        """Set up the main heaps and the per-thread heap pointer arrays.

        `datasets_pair`, `chunk_size` and `strategy` are forwarded to the base
        class constructor; `k` is validated to be an integer >= 1.

        Raises
        ------
        MemoryError
            If one of the per-thread pointer arrays cannot be allocated.
        """
        super().__init__(
            datasets_pair=datasets_pair,
            chunk_size=chunk_size,
            strategy=strategy,
        )
        self.k = check_scalar(k, "k", Integral, min_val=1)

        # Allocating pointers to datastructures but not the datastructures themselves.
        # There are as many pointers as effective threads.
        #
        # For the sake of explicitness:
        #   - when parallelizing on X, the pointers of those heaps are referencing
        #   (with proper offsets) addresses of the two main heaps (see below)
        #   - when parallelizing on Y, the pointers of those heaps are referencing
        #   small heaps which are thread-wise-allocated and whose content will be
        #   merged with the main heaps'.
        self.heaps_r_distances_chunks = <float64_t **> malloc(
            sizeof(float64_t *) * self.chunks_n_threads
        )
        self.heaps_indices_chunks = <intp_t **> malloc(
            sizeof(intp_t *) * self.chunks_n_threads
        )
        # Fail early rather than dereferencing a NULL pointer array later in
        # nogil code. Whichever array was successfully allocated is released
        # by `__dealloc__`.
        # NOTE(review): assumes chunks_n_threads >= 1, so that a NULL result
        # always denotes an allocation failure -- confirm in the base class.
        if (
            self.heaps_r_distances_chunks is NULL
            or self.heaps_indices_chunks is NULL
        ):
            raise MemoryError()

        # Main heaps which will be returned as results by `ArgKmin{{name_suffix}}.compute`.
        self.argkmin_indices = np.full((self.n_samples_X, self.k), 0, dtype=np.intp)
        self.argkmin_distances = np.full((self.n_samples_X, self.k), DBL_MAX, dtype=np.float64)

    def __dealloc__(self):
        # Only the arrays of pointers are released here: what they point to is
        # either part of the main heaps or freed in `_parallel_on_Y_finalize`.
        if self.heaps_indices_chunks is not NULL:
            free(self.heaps_indices_chunks)

        if self.heaps_r_distances_chunks is not NULL:
            free(self.heaps_r_distances_chunks)

    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        """Push the surrogate distances between the rows of `X[X_start:X_end]`
        and of `Y[Y_start:Y_end]` on the heaps of thread `thread_num`.

        Heaps are addressed relatively to the chunk (row `i` of the chunk uses
        the slots `[i * k, (i + 1) * k)`) while the pushed indices are absolute
        indices in Y.
        """
        cdef:
            intp_t i, j
            intp_t n_samples_X = X_end - X_start
            intp_t n_samples_Y = Y_end - Y_start
            float64_t *heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            intp_t *heaps_indices = self.heaps_indices_chunks[thread_num]

        # Pushing the distances and their associated indices on a heap
        # which by construction will keep track of the argkmin.
        for i in range(n_samples_X):
            for j in range(n_samples_Y):
                heap_push(
                    values=heaps_r_distances + i * self.k,
                    indices=heaps_indices + i * self.k,
                    size=self.k,
                    val=self.datasets_pair.surrogate_dist(X_start + i, Y_start + j),
                    val_idx=Y_start + j,
                )

    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Make the heaps of thread `thread_num` alias the rows of the main
        heaps starting at `X_start`."""
        # As this strategy is embarrassingly parallel, we can set each
        # thread's heaps pointer to the proper position on the main heaps.
        self.heaps_r_distances_chunks[thread_num] = &self.argkmin_distances[X_start, 0]
        self.heaps_indices_chunks[thread_num] = &self.argkmin_indices[X_start, 0]

    cdef void _parallel_on_X_prange_iter_finalize(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Sort the heaps of the rows `X[X_start:X_end]` once the chunk has
        been fully processed by thread `thread_num`."""
        cdef:
            intp_t idx

        # Sorting the main heaps portion associated to `X[X_start:X_end]`
        # in ascending order w.r.t the distances.
        for idx in range(X_end - X_start):
            simultaneous_sort(
                self.heaps_r_distances_chunks[thread_num] + idx * self.k,
                self.heaps_indices_chunks[thread_num] + idx * self.k,
                self.k
            )

    cdef void _parallel_on_Y_init(
        self,
    ) noexcept nogil:
        """Allocate one pair of thread-local heaps per thread.

        The buffers are released in `_parallel_on_Y_finalize`.
        """
        cdef:
            # Maximum number of scalar elements (the last chunks can be smaller)
            intp_t heaps_size = self.X_n_samples_chunk * self.k
            intp_t thread_num

        # The allocation is done in parallel for data locality purposes: this way
        # the heaps used in each threads are allocated in pages which are closer
        # to the CPU core used by the thread.
        # See comments about First Touch Placement Policy:
        # https://www.openmp.org/wp-content/uploads/openmp-webinar-vanderPas-20210318.pdf #noqa
        for thread_num in prange(self.chunks_n_threads, schedule='static', nogil=True,
                                 num_threads=self.chunks_n_threads):
            # As chunks of X are shared across threads, so must their
            # heaps. To solve this, each thread has its own heaps
            # which are then synchronised back in the main ones.
            #
            # NOTE: the results of those allocations are not checked here
            # (this method is `noexcept nogil` and cannot raise).
            self.heaps_r_distances_chunks[thread_num] = <float64_t *> malloc(
                heaps_size * sizeof(float64_t)
            )
            self.heaps_indices_chunks[thread_num] = <intp_t *> malloc(
                heaps_size * sizeof(intp_t)
            )

    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Reset the thread-local heaps of thread `thread_num`.

        Every slot gets the distance `DBL_MAX` and the index `-1`.
        """
        cdef:
            # Declared explicitly, consistently with the other nogil methods,
            # rather than relying on type inference.
            intp_t idx

        # Initialising heaps (memset can't be used here)
        for idx in range(self.X_n_samples_chunk * self.k):
            self.heaps_r_distances_chunks[thread_num][idx] = DBL_MAX
            self.heaps_indices_chunks[thread_num][idx] = -1

    @final
    cdef void _parallel_on_Y_synchronize(
        self,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Merge the thread-local heaps of the rows `X[X_start:X_end]` into
        the main heaps."""
        cdef:
            intp_t idx, jdx, thread_num
        with nogil, parallel(num_threads=self.effective_n_threads):
            # Synchronising the thread heaps with the main heaps.
            # This is done in parallel sample-wise (no need for locks).
            #
            # This might break each thread's data locality as each heap which
            # was allocated in a thread is now being used in several threads.
            #
            # Still, this parallel pattern has shown to be efficient in practice.
            for idx in prange(X_end - X_start, schedule="static"):
                for thread_num in range(self.chunks_n_threads):
                    for jdx in range(self.k):
                        heap_push(
                            values=&self.argkmin_distances[X_start + idx, 0],
                            indices=&self.argkmin_indices[X_start + idx, 0],
                            size=self.k,
                            val=self.heaps_r_distances_chunks[thread_num][idx * self.k + jdx],
                            val_idx=self.heaps_indices_chunks[thread_num][idx * self.k + jdx],
                        )

    cdef void _parallel_on_Y_finalize(
        self,
    ) noexcept nogil:
        """Free the thread-local heaps and sort every row of the main heaps."""
        cdef:
            intp_t idx, thread_num

        with nogil, parallel(num_threads=self.chunks_n_threads):
            # Deallocating temporary datastructures
            for thread_num in prange(self.chunks_n_threads, schedule='static'):
                free(self.heaps_r_distances_chunks[thread_num])
                free(self.heaps_indices_chunks[thread_num])

            # Sorting the main heaps in ascending order w.r.t the distances.
            # This is done in parallel sample-wise (no need for locks).
            for idx in prange(self.n_samples_X, schedule='static'):
                simultaneous_sort(
                    &self.argkmin_distances[idx, 0],
                    &self.argkmin_indices[idx, 0],
                    self.k,
                )

    cdef void compute_exact_distances(self) noexcept nogil:
        """Convert, in place, the surrogate distances stored in the main heap
        into distances using the metric's `_rdist_to_dist`."""
        cdef:
            intp_t i, j
            float64_t[:, ::1] distances = self.argkmin_distances
        for i in prange(self.n_samples_X, schedule='static', nogil=True,
                        num_threads=self.effective_n_threads):
            for j in range(self.k):
                distances[i, j] = self.datasets_pair.distance_metric._rdist_to_dist(
                    # Guard against potential -0., causing nan production.
                    max(distances[i, j], 0.)
                )

    def _finalize_results(self, bint return_distance=False):
        """Return the main heaps as NumPy arrays.

        Returns the indices only or, if `return_distance` is set, the tuple
        `(distances, indices)` after the distances have been made exact.
        """
        if return_distance:
            # We need to recompute distances because we relied on
            # surrogate distances for the reduction.
            self.compute_exact_distances()

            # Values are returned identically to the way `KNeighborsMixin.kneighbors`
            # returns values. This is counter-intuitive but this allows not using
            # complex adaptations where `ArgKmin.compute` is called.
            return np.asarray(self.argkmin_distances), np.asarray(self.argkmin_indices)

        return np.asarray(self.argkmin_indices)


cdef class EuclideanArgKmin{{name_suffix}}(ArgKmin{{name_suffix}}):
    """EuclideanDistance-specialisation of ArgKmin{{name_suffix}}.

    Squared Euclidean distances are assembled chunk by chunk as the sum of the
    squared norm of the row of X, the squared norm of the row of Y and a
    "middle term" delegated to a `MiddleTermComputer{{name_suffix}}`.

    Every `_parallel_on_*` hook below first runs the parent implementation and
    then forwards the call to the middle term computer.
    """

    @classmethod
    def is_usable_for(cls, X, Y, metric) -> bool:
        """Return True if the parent class is usable for those arguments and
        the runtime is not in an unstable OpenBLAS configuration."""
        return (ArgKmin{{name_suffix}}.is_usable_for(X, Y, metric) and
                not _in_unstable_openblas_configuration())

    def __init__(
        self,
        X,
        Y,
        intp_t k,
        bint use_squared_distances=False,
        chunk_size=None,
        strategy=None,
        metric_kwargs=None,
    ):
        """Set up the middle term computer and the squared row norms.

        Parameters
        ----------
        X, Y : datasets
            The two datasets the reduction is computed on.
        k : int
            Size of the heap kept for each row of X.
        use_squared_distances : bool, default=False
            If set, the squared distances are reported as is instead of being
            converted to Euclidean distances.
        chunk_size, strategy : default=None
            Forwarded to the parent constructor.
        metric_kwargs : dict, default=None
            Only "X_norm_squared" and "Y_norm_squared" (precomputed squared
            row norms) are used. A `UserWarning` is emitted if any other key
            is present. The dictionary is left unmodified.
        """
        if (
            isinstance(metric_kwargs, dict) and
            (metric_kwargs.keys() - {"X_norm_squared", "Y_norm_squared"})
        ):
            # The class name is templated so that the float32 specialisation
            # does not report itself as the float64 one.
            warnings.warn(
                f"Some metric_kwargs have been passed ({metric_kwargs}) but aren't "
                "usable for this case (EuclideanArgKmin{{name_suffix}}) and will be ignored.",
                UserWarning,
                stacklevel=3,
            )

        super().__init__(
            # The datasets pair here is used for exact distances computations
            datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric="euclidean"),
            chunk_size=chunk_size,
            strategy=strategy,
            k=k,
        )
        cdef:
            intp_t dist_middle_terms_chunks_size = self.Y_n_samples_chunk * self.X_n_samples_chunk

        self.middle_term_computer = MiddleTermComputer{{name_suffix}}.get_for(
            X,
            Y,
            self.effective_n_threads,
            self.chunks_n_threads,
            dist_middle_terms_chunks_size,
            n_features=X.shape[1],
            chunk_size=self.chunk_size,
        )

        # The precomputed norms are read with a plain lookup (not `pop`) so
        # that the dictionary owned by the caller is not mutated and can be
        # reused for subsequent calls.
        if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
            self.Y_norm_squared = check_array(
                metric_kwargs["Y_norm_squared"],
                ensure_2d=False,
                input_name="Y_norm_squared",
                dtype=np.float64,
            )
        else:
            self.Y_norm_squared = _sqeuclidean_row_norms{{name_suffix}}(
                Y,
                self.effective_n_threads,
            )

        if metric_kwargs is not None and "X_norm_squared" in metric_kwargs:
            self.X_norm_squared = check_array(
                metric_kwargs["X_norm_squared"],
                ensure_2d=False,
                input_name="X_norm_squared",
                dtype=np.float64,
            )
        else:
            # Do not recompute norms if datasets are identical.
            self.X_norm_squared = (
                self.Y_norm_squared if X is Y else
                _sqeuclidean_row_norms{{name_suffix}}(
                    X,
                    self.effective_n_threads,
                )
            )

        self.use_squared_distances = use_squared_distances

    @final
    cdef void compute_exact_distances(self) noexcept nogil:
        """Convert the stored values to Euclidean distances, unless squared
        distances were requested, in which case they are left untouched."""
        if not self.use_squared_distances:
            ArgKmin{{name_suffix}}.compute_exact_distances(self)

    @final
    cdef void _parallel_on_X_parallel_init(
        self,
        intp_t thread_num,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_X_parallel_init(self, thread_num)
        self.middle_term_computer._parallel_on_X_parallel_init(thread_num)

    @final
    cdef void _parallel_on_X_init_chunk(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_X_init_chunk(self, thread_num, X_start, X_end)
        self.middle_term_computer._parallel_on_X_init_chunk(thread_num, X_start, X_end)

    @final
    cdef void _parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
            self,
            X_start, X_end,
            Y_start, Y_end,
            thread_num,
        )
        self.middle_term_computer._parallel_on_X_pre_compute_and_reduce_distances_on_chunks(
            X_start, X_end, Y_start, Y_end, thread_num,
        )

    @final
    cdef void _parallel_on_Y_init(
        self,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_Y_init(self)
        self.middle_term_computer._parallel_on_Y_init()

    @final
    cdef void _parallel_on_Y_parallel_init(
        self,
        intp_t thread_num,
        intp_t X_start,
        intp_t X_end,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_Y_parallel_init(self, thread_num, X_start, X_end)
        self.middle_term_computer._parallel_on_Y_parallel_init(thread_num, X_start, X_end)

    @final
    cdef void _parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        """Run the parent hook, then the middle term computer's one."""
        ArgKmin{{name_suffix}}._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
            self,
            X_start, X_end,
            Y_start, Y_end,
            thread_num,
        )
        self.middle_term_computer._parallel_on_Y_pre_compute_and_reduce_distances_on_chunks(
            X_start, X_end, Y_start, Y_end, thread_num
        )

    @final
    cdef void _compute_and_reduce_distances_on_chunks(
        self,
        intp_t X_start,
        intp_t X_end,
        intp_t Y_start,
        intp_t Y_end,
        intp_t thread_num,
    ) noexcept nogil:
        """Push the squared Euclidean distances between the rows of
        `X[X_start:X_end]` and of `Y[Y_start:Y_end]` on the heaps of thread
        `thread_num`.

        The middle terms are read as a row-major buffer of `n_X` rows of `n_Y`
        values (presumably the `-2 <x_i, y_j>` cross terms -- see
        MiddleTermComputer).
        """
        cdef:
            intp_t i, j
            float64_t sqeuclidean_dist_i_j
            intp_t n_X = X_end - X_start
            intp_t n_Y = Y_end - Y_start
            float64_t * dist_middle_terms = self.middle_term_computer._compute_dist_middle_terms(
                X_start, X_end, Y_start, Y_end, thread_num
            )
            float64_t * heaps_r_distances = self.heaps_r_distances_chunks[thread_num]
            intp_t * heaps_indices = self.heaps_indices_chunks[thread_num]

        # Pushing the distances and their associated indices on heaps
        # which keep track of the argkmin.
        for i in range(n_X):
            for j in range(n_Y):
                sqeuclidean_dist_i_j = (
                    self.X_norm_squared[i + X_start] +
                    dist_middle_terms[i * n_Y + j] +
                    self.Y_norm_squared[j + Y_start]
                )

                # Catastrophic cancellation might cause -0. to be present,
                # e.g. when computing d(x_i, y_i) when X is Y.
                sqeuclidean_dist_i_j = max(0., sqeuclidean_dist_i_j)

                heap_push(
                    values=heaps_r_distances + i * self.k,
                    indices=heaps_indices + i * self.k,
                    size=self.k,
                    val=sqeuclidean_dist_i_j,
                    val_idx=j + Y_start,
                )

{{endfor}}
