# Recovered from a flattened git patch:
#   commit:  49ca6ade1a19f7f9c702171537fe4be0dfcda66d
#   from:    Gustaf Rydholm
#   date:    Fri, 25 Aug 2023 23:19:14 +0200
#   subject: Rename and add flash atten
# The patch deleted text_recognizer/networks/conv_transformer.py (49 lines,
# blob d36162a). The module below is the content that the patch removed.
"""Base network module."""
from typing import Type

from torch import Tensor, nn

from text_recognizer.networks.transformer.decoder import Decoder


class ConvTransformer(nn.Module):
    """Base transformer network: an image encoder followed by a decoder.

    The class only wires the two sub-modules together; all computation is
    delegated to ``encoder`` and ``decoder``.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: Decoder,
    ) -> None:
        """Store the encoder and decoder as sub-modules.

        Args:
            encoder: Module instance called as ``encoder(img)``. It is an
                instance, not a class, because it is invoked directly on the
                image tensor in :meth:`encode`.
            decoder: Module called as ``decoder(tokens, img_features)``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, img: Tensor) -> Tensor:
        """Encode images to a latent representation.

        Args:
            img: Input image(s), forwarded unchanged to the encoder.

        Returns:
            Whatever the encoder produces for ``img``.
        """
        return self.encoder(img)

    def decode(self, tokens: Tensor, img_features: Tensor) -> Tensor:
        """Decode latent image features, conditioned on ``tokens``.

        Args:
            tokens: Token tensor, forwarded unchanged to the decoder.
            img_features: Encoder output, as returned by :meth:`encode`.

        Returns:
            Whatever the decoder produces for ``(tokens, img_features)``.
        """
        return self.decoder(tokens, img_features)

    def forward(self, img: Tensor, tokens: Tensor) -> Tensor:
        """Encode images and decode them into token logits.

        Args:
            img (Tensor): Input image(s).
            tokens (Tensor): Token embeddings.

        Shapes:
            - img: :math:`(B, 1, H, W)`
            - tokens: :math:`(B, Sy)`
            - logits: :math:`(B, Sy, C)`

            where B is the batch size, H is the image height, W is the image
            width, Sy the output length, and C is the number of classes.

        Returns:
            Tensor: Sequence of logits.
        """
        img_features = self.encode(img)
        logits = self.decode(tokens, img_features)
        return logits