lassification": classification_preds, "Regression": regression_preds} >>> >>> metrics = MultitaskWrapper({ ... "Classification": BinaryAccuracy(), ... "Regression": MeanSquaredError() ... }) >>> metrics.update(preds, targets) >>> value = metrics.compute() >>> fig_, ax_ = metrics.plot(value) .. plot:: :scale: 75 >>> # Example plotting multiple values >>> import torch >>> from torchmetrics.wrappers import MultitaskWrapper >>> from torchmetrics.regression import MeanSquaredError >>> from torchmetrics.classification import BinaryAccuracy >>> >>> classification_target = torch.tensor([0, 1, 0]) >>> regression_target = torch.tensor([2.5, 5.0, 4.0]) >>> targets = {"Classification": classification_target, "Regression": regression_target} >>> >>> classification_preds = torch.tensor([0, 0, 1]) >>> regression_preds = torch.tensor([3.0, 5.0, 2.5]) >>> preds = {"Classification": classification_preds, "Regression": regression_preds} >>> >>> metrics = MultitaskWrapper({ ... "Classification": BinaryAccuracy(), ... "Regression": MeanSquaredError() ... }) >>> values = [] >>> for _ in range(10): ... values.append(metrics(preds, targets)) >>> fig_, ax_ = metrics.plot(values) Nz>Expected argument `axes` to be a Sequence. Found type(axes) = c