>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)

)