From 65d5f6c694e73792e40ed693a1381a792da8d277 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 3 Aug 2021 19:14:16 +0200 Subject: Fix bugs in converting text in mappings, add missing word_piece arg in datamodule --- text_recognizer/data/word_piece_mapping.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) (limited to 'text_recognizer/data/word_piece_mapping.py') diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py index 59488c3..2f650cd 100644 --- a/text_recognizer/data/word_piece_mapping.py +++ b/text_recognizer/data/word_piece_mapping.py @@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping): def get_text(self, indices: Union[List[int], Tensor]) -> str: if isinstance(indices, Tensor): indices = indices.tolist() - return self.wordpiece_processor.to_text(indices).replace(" ", "▁") + return self.wordpiece_processor.to_text(indices) def get_indices(self, text: str) -> Tensor: return self.wordpiece_processor.to_index(text) @@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping): text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) - def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]: - if isinstance(x, int): - x = [x] - if isinstance(x, str): - return self.get_indices(x) - return self.get_text(x) + def __getitem__(self, x: Union[int, Tensor]) -> str: + return self.get_token(x) -- cgit v1.2.3-70-g09d2