# Origin: commit 8c4a0c2603975cfc63f4e4019386e001387c42c9
# Author: Gustaf Rydholm
# Date: Tue, 4 Oct 2022 22:08:38 +0200
# Subject: Add greedy decoder
# File: text_recognizer/models/greedy_decoder.py
"""Greedy decoder."""
from typing import TYPE_CHECKING

import torch
from torch import Tensor, nn

if TYPE_CHECKING:
    # Only needed for the type annotation below; avoids a runtime dependency.
    from text_recognizer.data.tokenizer import Tokenizer


class GreedyDecoder:
    """Autoregressive greedy (argmax) decoder.

    Wraps a network exposing ``encode(x)`` and ``decode(tokens, features)``
    and repeatedly picks the most likely next token until every sequence in
    the batch has produced an end (or pad) token, or ``max_output_len`` is
    reached.
    """

    def __init__(
        self,
        network: nn.Module,
        tokenizer: "Tokenizer",
        max_output_len: int = 682,
    ) -> None:
        """Store the network and the special token indices.

        Args:
            network: Model instance providing ``encode`` and ``decode``.
            tokenizer: Object providing ``start_index``, ``end_index`` and
                ``pad_index``.
            max_output_len: Length of the returned token sequences
                (including the start token).

        Raises:
            ValueError: If ``max_output_len`` is smaller than 1, since the
                output must at least hold the start token.
        """
        if max_output_len < 1:
            raise ValueError(
                f"max_output_len must be at least 1, got {max_output_len}"
            )
        self.network = network
        self.start_index = tokenizer.start_index
        self.end_index = tokenizer.end_index
        self.pad_index = tokenizer.pad_index
        self.max_output_len = max_output_len

    @torch.no_grad()
    def __call__(self, x: Tensor) -> Tensor:
        """Greedily decode a batch of inputs into token indices.

        Args:
            x: Batch of inputs; the first dimension is the batch size.

        Returns:
            A ``(B, max_output_len)`` long tensor on ``x``'s device. Each row
            starts with the start token; every position after the first
            end/pad token is the pad token.
        """
        bsz = x.shape[0]

        # Encode image(s) to latent vectors.
        img_features = self.network.encode(x)

        # Output buffer, pre-filled with padding so that positions that are
        # never decoded (after an early stop) are already correct.
        indices = torch.full(
            (bsz, self.max_output_len),
            self.pad_index,
            dtype=torch.long,
            device=x.device,
        )
        indices[:, 0] = self.start_index

        # Tracks, per sequence, whether an end/pad token has been produced.
        finished = torch.zeros(bsz, dtype=torch.bool, device=x.device)

        for step in range(1, self.max_output_len):
            tokens = indices[:, :step]  # (B, Sy)
            # decode is expected to return logits shaped (B, C, Sy).
            logits = self.network.decode(tokens, img_features)
            # Only the prediction for the last position is new.
            next_tokens = logits[:, :, -1].argmax(dim=1)  # (B,)
            # Sequences that already ended only ever receive padding.
            next_tokens = next_tokens.masked_fill(finished, self.pad_index)
            indices[:, step] = next_tokens

            finished |= (next_tokens == self.end_index) | (
                next_tokens == self.pad_index
            )
            # Stop as soon as every sequence in the batch has ended.
            if finished.all():
                break

        return indices