diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-06-01 23:10:12 +0200 |
| commit | db86cef2d308f58325278061c6aa177a535e7e03 (patch) | |
| tree | a013fa85816337269f9cdc5a8992813fa62d299d /text_recognizer/models/transformer.py | |
| parent | b980a281712a5b1ee7ee5bd8f5d4762cd91a070b (diff) | |
Replace attr with attrs
Diffstat (limited to 'text_recognizer/models/transformer.py')
| -rw-r--r-- | text_recognizer/models/transformer.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 7272f46..c5120fe 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,7 +1,7 @@ """PyTorch Lightning model for base Transformers.""" from typing import Set, Tuple -import attr +from attrs import define, field import torch from torch import Tensor @@ -9,22 +9,22 @@ from text_recognizer.models.base import BaseLitModel from text_recognizer.models.metrics import CharacterErrorRate -@attr.s(auto_attribs=True, eq=False) +@define(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - max_output_len: int = attr.ib(default=451) - start_token: str = attr.ib(default="<s>") - end_token: str = attr.ib(default="<e>") - pad_token: str = attr.ib(default="<p>") + max_output_len: int = field(default=451) + start_token: str = field(default="<s>") + end_token: str = field(default="<e>") + pad_token: str = field(default="<p>") - start_index: int = attr.ib(init=False) - end_index: int = attr.ib(init=False) - pad_index: int = attr.ib(init=False) + start_index: int = field(init=False) + end_index: int = field(init=False) + pad_index: int = field(init=False) - ignore_indices: Set[Tensor] = attr.ib(init=False) - val_cer: CharacterErrorRate = attr.ib(init=False) - test_cer: CharacterErrorRate = attr.ib(init=False) + ignore_indices: Set[Tensor] = field(init=False) + val_cer: CharacterErrorRate = field(init=False) + test_cer: CharacterErrorRate = field(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" |