        called and expected to return a batch of metric tensors ``(batch,)``.
        If `mode`==`'permutation-wise'`, then ``metric_func(preds[:, p, ...], target[:, :, ...])`` is called,
        where `p` is one possible permutation, e.g. ``[0, 1]`` or ``[1, 0]`` in the 2-speaker case, and is
        expected to return a batch of metric tensors ``(batch,)``.
    mode: either `'speaker-wise'` or `'permutation-wise'`.
    eval_func: the function used to select the best permutation, either ``'min'`` or ``'max'``,
        i.e. the smaller the better or the larger the better.
    kwargs: additional arguments passed to ``metric_func``.

Returns:
    Tuple of two tensors. The first, a float tensor with shape ``(batch,)``, contains the best metric value
    for each sample. The second, an integer tensor with shape ``(batch, spk)``, contains the best permutation
    for each sample.

Example:
    >>> import torch
    >>> from torchmetrics.functional.audio import (
    ...     permutation_invariant_training, pit_permutate, scale_invariant_signal_distortion_ratio)
    >>> # [batch, spk, time]
    >>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
    >>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
    >>> best_metric, best_perm = permutation_invariant_training(
    ...     preds, target, scale_invariant_signal_distortion_ratio,
    ...     mode="speaker-wise", eval_func="max")
    >>> best_metric
    tensor([-5.1091])
    >>> best_perm
    tensor([[0, 1]])
    >>> pit_permutate(preds, best_perm)
    tensor([[[-0.0579,  0.3560, -0.9604],
             [-0.1719,  0.3205,  0.2951]]])
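To make the search described above concrete, here is a minimal pure-Python sketch of exhaustive permutation search for a single sample. It is not the torchmetrics implementation. The helpers `pit_brute_force`, `permutate` and `neg_mse` are hypothetical, and signals are plain `[spk][time]` lists rather than tensors. It follows the speaker-wise idea: score every (prediction, target) pair once, then try each permutation `perm`, where `perm[j]` is the prediction assigned to target speaker `j`, and average over speakers.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]
Metric = Callable[[Signal, Signal], float]


def pit_brute_force(
    preds: Sequence[Signal],
    target: Sequence[Signal],
    metric: Metric,
    eval_func: str = "max",
) -> Tuple[float, List[int]]:
    """Exhaustive PIT for ONE sample; preds/target are [spk][time] lists (hypothetical helper)."""
    if eval_func not in ("max", "min"):
        raise ValueError("eval_func must be 'max' or 'min'")
    n = len(target)
    # Speaker-wise style: mtx[j][i] = metric(preds[i], target[j]), n * n metric calls in total.
    mtx = [[metric(preds[i], target[j]) for i in range(n)] for j in range(n)]
    best_score: float = 0.0
    best_perm: List[int] = []
    for perm in permutations(range(n)):
        # perm[j] is the prediction paired with target speaker j; average the pair scores.
        score = sum(mtx[j][perm[j]] for j in range(n)) / n
        better = score > best_score if eval_func == "max" else score < best_score
        if not best_perm or better:
            best_score, best_perm = score, list(perm)
    return best_score, best_perm


def permutate(preds: Sequence[Signal], perm: Sequence[int]) -> List[Signal]:
    """Reorder predicted speakers so that output[j] == preds[perm[j]] (hypothetical helper)."""
    return [preds[p] for p in perm]


def neg_mse(pred: Signal, tgt: Signal) -> float:
    """Toy 'larger is better' metric: negative mean squared error."""
    return -sum((a - b) ** 2 for a, b in zip(pred, tgt)) / len(tgt)


# Two speakers whose predictions come out in swapped order.
preds = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
target = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
score, perm = pit_brute_force(preds, target, neg_mse, eval_func="max")
print(perm)  # [1, 0]
print(permutate(preds, perm) == target)  # True
```

The loop visits all `spk!` permutations, which is cheap for two or three speakers but grows quickly beyond that. Precomputing the pair matrix keeps the number of metric calls at `spk * spk` regardless of how many permutations are tried.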