ulticlassConfusionMatrix): ... def update(self, preds, target): ... super().update(preds, target) ... # by construction make future states dependent on prior states ... if self.confmat.sum() > 20: ... self.reset() >>> check_forward_full_state_property( ... MyMetric, ... init_args = {'num_classes': 3}, ... input_args = {'preds': torch.randint(3, (10,)), 'target': torch.randint(3, (10,))}, ... ) Recommended setting `full_state_update=True` c