From eb5b206f7e1b08435378d2a02395307be55ee6f1 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 6 Jul 2021 17:42:53 +0200 Subject: Refactoring data with attrs and refactor conf for hydra --- text_recognizer/models/transformer.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) (limited to 'text_recognizer/models/transformer.py') diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index ea54d83..8c9fe8a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -2,35 +2,24 @@ from typing import Dict, List, Optional, Union, Tuple, Type import attr +import hydra from omegaconf import DictConfig from torch import nn, Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate -from text_recognizer.models.base import LitBaseModel +from text_recognizer.models.base import BaseLitModel -@attr.s -class TransformerLitModel(LitBaseModel): +@attr.s(auto_attribs=True) +class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - network: Type[nn.Module] = attr.ib() - criterion_config: DictConfig = attr.ib(converter=DictConfig) - optimizer_config: DictConfig = attr.ib(converter=DictConfig) - lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) - monitor: str = attr.ib() - mapping: Type[AbstractMapping] = attr.ib() + mapping_config: DictConfig = attr.ib(converter=DictConfig) def __attrs_post_init__(self) -> None: - super().__init__( - network=self.network, - optimizer_config=self.optimizer_config, - lr_scheduler_config=self.lr_scheduler_config, - criterion_config=self.criterion_config, - monitor=self.monitor, - ) - self.mapping, ignore_tokens = self.configure_mapping(mapping) + self.mapping, ignore_tokens = self._configure_mapping() self.val_cer = CharacterErrorRate(ignore_tokens) self.test_cer = CharacterErrorRate(ignore_tokens) @@ -39,9 +28,10 @@ class TransformerLitModel(LitBaseModel): return self.network.predict(data) @staticmethod - def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]: + def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: """Configure mapping.""" # TODO: Fix me!!! + # Load config with hydra mapping, inverse_mapping, _ = emnist_mapping(["\n"]) start_index = inverse_mapping[""] end_index = inverse_mapping[""] -- cgit v1.2.3-70-g09d2