diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:26 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:10:26 +0200 |
| commit | 0540237d794ab2071764dc74e4d3bb52f5bf44be (patch) | |
| tree | dad3469f843da16716871d0b9805bf0301aa6cfe /text_recognizer/models/base.py | |
| parent | bf680dce6bc7dcadd20923a193fc9ab8fbd0a0c6 (diff) | |
Update metrics
Diffstat (limited to 'text_recognizer/models/base.py')
| -rw-r--r-- | text_recognizer/models/base.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f917635..bb4e695 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,7 +6,7 @@ import torch from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import LightningModule -from torch import Tensor, nn +from torch import nn, Tensor from torchmetrics import Accuracy from text_recognizer.data.mappings import EmnistMapping @@ -22,6 +22,7 @@ class LitBase(LightningModule): optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], mapping: EmnistMapping, + ignore_index: Optional[int] = None, ) -> None: super().__init__() @@ -32,9 +33,9 @@ class LitBase(LightningModule): self.mapping = mapping # Placeholders - self.train_acc = Accuracy(mdmc_reduce="samplewise") - self.val_acc = Accuracy(mdmc_reduce="samplewise") - self.test_acc = Accuracy(mdmc_reduce="samplewise") + self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) def optimizer_zero_grad( self, |