From e6717d5a872e236f90977519a76cb35446ab0d5d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Thu, 3 Feb 2022 21:39:17 +0100
Subject: chore: remove axial attention

---
 text_recognizer/networks/conv_transformer.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index a068ea3..5b29362 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -6,7 +6,6 @@ from loguru import logger as log
 from torch import nn, Tensor
 
 from text_recognizer.networks.base import BaseTransformer
-from text_recognizer.networks.transformer.axial_attention.encoder import AxialEncoder
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
     AxialPositionalEmbedding,
@@ -24,7 +23,6 @@ class ConvTransformer(BaseTransformer):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        axial_encoder: Optional[AxialEncoder],
         pixel_pos_embedding: AxialPositionalEmbedding,
         token_pos_embedding: Optional[Type[nn.Module]] = None,
     ) -> None:
@@ -39,7 +37,6 @@ class ConvTransformer(BaseTransformer):
         )
         self.pixel_pos_embedding = pixel_pos_embedding
-        self.axial_encoder = axial_encoder
 
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
@@ -79,7 +76,6 @@ class ConvTransformer(BaseTransformer):
         z = self.encoder(x)
         z = self.conv(z)
         z = self.pixel_pos_embedding(z)
-        z = self.axial_encoder(z) if self.axial_encoder is not None else z
         z = z.flatten(start_dim=2)
 
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
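
Context for the last hunk: after the convolutional encoder, the 1x1 projection, and the 2D positional embedding, the feature map is flattened and permuted into a token sequence for the transformer decoder. The snippet below is a minimal sketch of that reshaping only; the shape values (B, E, Ho, Wo) are made-up illustrative numbers, not values taken from this repository.

import torch

# Illustrative shapes only (assumed, not from the repository):
# the encoder is taken to emit a [B, E, Ho, Wo] feature map.
B, E, Ho, Wo = 2, 128, 7, 32
z = torch.randn(B, E, Ho, Wo)

# Flatten the spatial dimensions: [B, E, Ho, Wo] -> [B, E, Ho * Wo]
z = z.flatten(start_dim=2)

# Permute to sequence-first layout for the decoder: [B, E, Sx] -> [B, Sx, E]
z = z.permute(0, 2, 1)

print(z.shape)  # torch.Size([2, 224, 128])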