From 38202e9c6c1155d96ee0f6e9f337022ee4eeb7e3 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 5 Apr 2021 23:12:20 +0200
Subject: Add OmegaConf for configs

---
 text_recognizer/models/transformer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 285b715..3625ab2 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,7 @@
 """PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type, Union
+from omegaconf import OmegaConf
 
 import pytorch_lightning as pl
 import torch
 from torch import nn
@@ -18,15 +19,15 @@ class LitTransformerModel(LitBaseModel):
 
     def __init__(
         self,
-        network_args: Dict,
-        optimizer_args: Dict,
-        lr_scheduler_args: Dict,
-        criterion_args: Dict,
+        network: Type[nn.Module],
+        optimizer: Union[OmegaConf, Dict],
+        lr_scheduler: Union[OmegaConf, Dict],
+        criterion: Union[OmegaConf, Dict],
         monitor: str = "val_loss",
         mapping: Optional[List[str]] = None,
     ) -> None:
         super().__init__(
-            network_args, optimizer_args, lr_scheduler_args, criterion_args, monitor
+            network, optimizer, lr_scheduler, criterion, monitor
         )
         self.mapping, ignore_tokens = self.configure_mapping(mapping)
 
@@ -40,6 +41,7 @@ class LitTransformerModel(LitBaseModel):
     @staticmethod
     def configure_mapping(mapping: Optional[List[str]]) -> Tuple[List[str], List[int]]:
         """Configure mapping."""
+        # TODO: Fix me!!!
         mapping, inverse_mapping, _ = emnist_mapping()
         start_index = inverse_mapping[""]
         end_index = inverse_mapping[""]
--
cgit v1.2.3-70-g09d2
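
Usage sketch (not part of the patch): after this change the constructor takes the
network class plus OmegaConf (or plain dict) configs instead of the old *_args dicts.
The config keys ("type"/"args") and the network class used below are assumptions for
illustration only, not taken from the repository.

    from omegaconf import OmegaConf
    from torch import nn

    from text_recognizer.models.transformer import LitTransformerModel

    # Assumed config schema; the project may use different keys.
    optimizer_cfg = OmegaConf.create({"type": "Adam", "args": {"lr": 1e-3}})
    lr_scheduler_cfg = OmegaConf.create({"type": "OneCycleLR", "args": {"max_lr": 1e-3}})
    criterion_cfg = OmegaConf.create({"type": "CrossEntropyLoss", "args": {}})

    model = LitTransformerModel(
        network=nn.Transformer,  # hypothetical stand-in; pass the project's network class
        optimizer=optimizer_cfg,
        lr_scheduler=lr_scheduler_cfg,
        criterion=criterion_cfg,
        monitor="val_loss",
        mapping=None,
    )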