diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 23:11:06 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 23:11:06 +0200 |
| commit | 9c7dbb9ca70858b870f74ecf595d3169f0cbc711 (patch) | |
| tree | c342e2c004bb75571a380ef2805049a8fcec3fcc /text_recognizer/models/base.py | |
| parent | 9b8e14d89f0ef2508ed11f994f73af624155fe1d (diff) | |
Rename mapping to tokenizer
Diffstat (limited to 'text_recognizer/models/base.py')
| -rw-r--r-- | text_recognizer/models/base.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index bb4e695..f8f4b40 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -9,7 +9,7 @@ from pytorch_lightning import LightningModule from torch import nn, Tensor from torchmetrics import Accuracy -from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.tokenizer import Tokenizer class LitBase(LightningModule): @@ -21,8 +21,7 @@ class LitBase(LightningModule): loss_fn: Type[nn.Module], optimizer_config: DictConfig, lr_scheduler_config: Optional[DictConfig], - mapping: EmnistMapping, - ignore_index: Optional[int] = None, + tokenizer: Tokenizer, ) -> None: super().__init__() @@ -30,8 +29,8 @@ class LitBase(LightningModule): self.loss_fn = loss_fn self.optimizer_config = optimizer_config self.lr_scheduler_config = lr_scheduler_config - self.mapping = mapping - + self.tokenizer = tokenizer + ignore_index = int(self.tokenizer.get_value("<p>")) # Placeholders self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) |