author  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-10 18:04:50 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-10-10 18:04:50 +0200
commit  8291a87c64f9a5f18caec82201bea15579b49730 (patch)
tree  1c8bb3e07a3bd06086e182dd320f8408829ba81c /text_recognizer/data/transforms.py
parent  30e3ae483c846418b04ed48f014a4af2cf9a0771 (diff)
Move data utils to submodules
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r--  text_recognizer/data/transforms.py | 49 -
1 file changed, 0 insertions(+), 49 deletions(-)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
deleted file mode 100644
index 7f3e0d1..0000000
--- a/text_recognizer/data/transforms.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Transforms for PyTorch datasets."""
-from pathlib import Path
-from typing import Optional, Union, Type, Set
-
-import torch
-from torch import Tensor
-
-from text_recognizer.data.base_mapping import AbstractMapping
-from text_recognizer.data.word_piece_mapping import WordPieceMapping
-
-
-class WordPiece:
- """Converts EMNIST indices to Word Piece indices."""
-
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- data_dir: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
- extra_symbols: Optional[Set[str]] = {"\n",},
- max_len: int = 451,
- ) -> None:
- self.mapping = WordPieceMapping(
- data_dir=data_dir,
- num_features=num_features,
- tokens=tokens,
- lexicon=lexicon,
- use_words=use_words,
- prepend_wordsep=prepend_wordsep,
- special_tokens=special_tokens,
- extra_symbols=extra_symbols,
- )
- self.max_len = max_len
-
- def __call__(self, x: Tensor) -> Tensor:
- """Converts Emnist target tensor to Word piece target tensor."""
- y = self.mapping.emnist_to_wordpiece_indices(x)
- if len(y) < self.max_len:
- pad_len = self.max_len - len(y)
- y = torch.cat(
- (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
- )
- else:
- y = y[: self.max_len]
- return y
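
For reference, a minimal sketch of how the removed transform could have been applied before this commit. The import path is the pre-commit location shown in the diff; the sample indices are made up for illustration, and the sketch assumes the WordPieceMapping token and lexicon files resolve from the default data directory:

    # Hypothetical usage of the removed WordPiece transform (sketch, not from
    # this repository). Sample EMNIST indices below are illustrative only.
    import torch
    from text_recognizer.data.transforms import WordPiece  # pre-commit path

    transform = WordPiece(max_len=451)
    emnist_target = torch.LongTensor([10, 23, 41, 5])  # hypothetical indices
    wordpiece_target = transform(emnist_target)  # mapped, then <p>-padded
    assert wordpiece_target.shape[0] == 451  # always padded/truncated to max_len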