From 442eac315e4b8be19adab80fb7332d29f68c077c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 27 Jun 2021 20:25:25 +0200 Subject: Fixed bug in word pieces --- text_recognizer/data/mappings.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'text_recognizer/data/mappings.py') diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index 190febe..0d778b2 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping): special_tokens, ) + def __len__(self) -> int: + return len(self.wordpiece_processor.tokens) + def get_token(self, index: Union[int, Tensor]) -> str: if (index := int(index)) <= self.wordpiece_processor.num_tokens: return self.wordpiece_processor.tokens[index] @@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping): def get_index(self, token: str) -> Tensor: if token in self.wordpiece_processor.tokens: - return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token]) + return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: @@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping): text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) + + def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]: + if isinstance(x, str): + return self.get_index(x) + return self.get_token(x) -- cgit v1.2.3-70-g09d2