diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-05 23:39:11 +0200 |
| commit | 65df6a72b002c4b23d6f2eb545839e157f7f2aa0 (patch) | |
| tree | d78df1d7143dc9ff9e29afd4fd6bc7490bc79418 /text_recognizer/models/transformer.py | |
| parent | 8bc4b4cab00a2777a748c10fca9b3ee01e32277c (diff) | |
Remove attrs
Diffstat (limited to 'text_recognizer/models/transformer.py')
| -rw-r--r-- | text_recognizer/models/transformer.py | 29 |
1 files changed, 12 insertions, 17 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index c5120fe..9537dd9 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,6 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -from attrs import define, field import torch from torch import Tensor @@ -9,25 +8,21 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = field(default=451) - start_token: str = field(default="<s>") - end_token: str = field(default="<e>") - pad_token: str = field(default="<p>") - - start_index: int = field(init=False) - end_index: int = field(init=False) - pad_index: int = field(init=False) - - ignore_indices: Set[Tensor] = field(init=False) - val_cer: CharacterErrorRate = field(init=False) - test_cer: CharacterErrorRate = field(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" + def __init__( + self, + max_output_len: int = 451, + start_token: str = "<s>", + end_token: str = "<e>", + pad_token: str = "<p>", + ) -> None: + super().__init__() + self.max_output_len = max_output_len + self.start_token = start_token + self.end_token = end_token + self.pad_token = pad_token self.start_index = int(self.mapping.get_index(self.start_token)) self.end_index = int(self.mapping.get_index(self.end_token)) self.pad_index = int(self.mapping.get_index(self.pad_token)) |