From 9c7dbb9ca70858b870f74ecf595d3169f0cbc711 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 27 Sep 2022 23:11:06 +0200 Subject: Rename mapping to tokenizer --- text_recognizer/models/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'text_recognizer/models/base.py') diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bb4e695..f8f4b40 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule from torch import nn, Tensor from torchmetrics import Accuracy -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer class LitBase(LightningModule): @@ -21,8 +21,7 @@ class LitBase(LightningModule): loss_fn: Type[nn.Module], optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], - mapping: EmnistMapping, - ignore_index: Optional[int] = None, + tokenizer: Tokenizer, ) -> None: super().__init__() @@ -30,8 +29,8 @@ class LitBase(LightningModule): self.loss_fn = loss_fn self.optimizer_config = optimizer_config self.lr_scheduler_config = lr_scheduler_config - self.mapping = mapping - + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("

")) # Placeholders self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) -- cgit v1.2.3-70-g09d2